{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# deepOF model evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a dataset and a trained model, this notebook allows the user to \n", "\n", "* Load and inspect the different models (encoder, decoder, grouper, gmvaep)\n", "* Visualize reconstruction quality for a given model\n", "* Visualize a static latent space\n", "* Visualize trajectories on the latent space for a given video\n", "* sample from the latent space distributions and generate video clips showcasing generated data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.chdir(os.path.dirname(\"../\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import deepof.data\n", "import deepof.utils\n", "import numpy as np\n", "import pandas as pd\n", "import re\n", "import tensorflow as tf\n", "from collections import Counter\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from sklearn.manifold import TSNE\n", "from sklearn.decomposition import PCA\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "import umap\n", "\n", "from ipywidgets import interactive, interact, HBox, Layout, VBox\n", "from IPython import display\n", "from matplotlib.animation import FuncAnimation\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from ipywidgets import interact" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Define and run project" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepoftesttemp\")\n", "trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"trained_weights\")\n", "exclude_bodyparts = tuple([\"\"])\n", "window_size = 24" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "proj = deepof.data.project(\n", " path=path, smooth_alpha=0.999, exclude_bodyparts=exclude_bodyparts, arena_dims=[380],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "proj = proj.run(verbose=True)\n", "print(proj)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Load a pretrained deepOF model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "data_prep = coords.preprocess(\n", "    test_videos=0, window_step=1, window_size=window_size, shuffle=True\n", ")[0]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "[i for i in os.listdir(trained_network) if i.endswith(\"h5\")]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "deepof_weights = [i for i in os.listdir(trained_network) if i.endswith(\"h5\")][-1]\n", "deepof_weights" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Recover model parameters from the name of the weights file\n", "encoding = int(re.findall(r\"encoding=(\\\\d+)_\", deepof_weights)[0])\n", "k = int(re.findall(r\"k=(\\\\d+)_\", deepof_weights)[0])\n", "loss = re.findall(r\"loss=(.+?)_\", deepof_weights)[0]\n", "pheno = 0\n", "predictor = 0" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "(\n", "    encode_to_vector,\n", "    decoder,\n", "    grouper,\n", "    gmvaep,\n", "    prior,\n", "    posterior,\n", ") = deepof.models.SEQ_2_SEQ_GMVAE(\n", "    loss=loss,\n", "    number_of_components=k,\n", "    compile_model=True,\n", "    encoding=encoding,\n", "    predictor=predictor,\n", "    phenotype_prediction=pheno,\n", ").build(\n", "    data_prep.shape\n", ")\n", "\n", "gmvaep.load_weights(os.path.join(trained_network, deepof_weights))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Uncomment to see model summaries\n", "# encode_to_vector.summary()\n", "# decoder.summary()\n", "# grouper.summary()\n", "# gmvaep.summary()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Uncomment the calls below to plot the model structures\n", "def plot_model(model, name):\n", "    tf.keras.utils.plot_model(\n", "        model,\n", "        to_file=os.path.join(\n", "            path,\n", "            \"deepof_{}_{}.png\".format(name, datetime.now().strftime(\"%Y%m%d-%H%M%S\")),\n", "        ),\n", "        show_shapes=True,\n", "        show_dtype=False,\n", "        show_layer_names=True,\n", "        rankdir=\"TB\",\n", "        expand_nested=True,\n", "        dpi=200,\n", "    )\n", "\n", "\n", "# plot_model(encode_to_vector, \"encoder\")\n", "# plot_model(decoder, \"decoder\")\n", "# plot_model(grouper, \"grouper\")\n", "# plot_model(gmvaep, \"gmvaep\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Evaluate reconstruction (to be incorporated into deepof.evaluate)" ] },
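{ "cell_type": "markdown", "metadata": {}, "source": [ "Before the qualitative animation below, a quick quantitative check is useful. The following is a minimal sketch computing the mean absolute error between a subset of input windows and their reconstructions, assuming `gmvaep` returns an array of the same shape as its input (the animation cell below relies on the same behavior):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch: average reconstruction error over a subset of preprocessed windows\n", "subset = data_prep[:1000]\n", "reconstructions = gmvaep.predict(subset)\n", "mae = np.mean(np.abs(subset - reconstructions))\n", "print(\"Mean absolute reconstruction error (first 1000 windows): {:.4f}\".format(mae))" ] },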
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Auxiliary animation functions\n", "\n", "\n", "def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):\n", "    \"\"\"Generates a graph plot of the mouse\"\"\"\n", "    plots = []\n", "    rec_plots = []\n", "    for edge in edges:\n", "        (temp_plot,) = ax.plot(\n", "            [float(instant_x[edge[0]]), float(instant_x[edge[1]])],\n", "            [float(instant_y[edge[0]]), float(instant_y[edge[1]])],\n", "            color=\"#006699\",\n", "            linewidth=2.0,\n", "        )\n", "        (temp_rec_plot,) = ax.plot(\n", "            [float(instant_rec_x[edge[0]]), float(instant_rec_x[edge[1]])],\n", "            [float(instant_rec_y[edge[0]]), float(instant_rec_y[edge[1]])],\n", "            color=\"red\",\n", "            linewidth=2.0,\n", "        )\n", "        plots.append(temp_plot)\n", "        rec_plots.append(temp_rec_plot)\n", "    return plots, rec_plots\n", "\n", "\n", "def update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges):\n", "    \"\"\"Updates the graph plot to enable animation\"\"\"\n", "\n", "    for plot, edge in zip(plots, edges):\n", "        plot.set_data(\n", "            [float(x[edge[0]]), float(x[edge[1]])],\n", "            [float(y[edge[0]]), float(y[edge[1]])],\n", "        )\n", "    for plot, edge in zip(rec_plots, edges):\n", "        plot.set_data(\n", "            [float(rec_x[edge[0]]), float(rec_x[edge[1]])],\n", "            [float(rec_y[edge[0]]), float(rec_y[edge[1]])],\n", "        )" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Display a video with the original data superimposed on the reconstructions\n", "\n", "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", "random_exp = np.random.choice(list(coords.keys()), 1)[0]\n", "print(random_exp)\n", "\n", "\n", "def animate_mice_across_time(random_exp):\n", "\n", "    # Define canvas\n", "    fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", "\n", "    # Retrieve body graph\n", "    edges = deepof.utils.connect_mouse_topview()\n", "\n", "    for bpart in exclude_bodyparts:\n", "        if bpart:\n", "            edges.remove_node(bpart)\n", "\n", "    for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n", "        edges.remove_edge(\"Center\", limb)\n", "        if (\"Tail_base\", limb) in edges.edges():\n", "            edges.remove_edge(\"Tail_base\", limb)\n", "\n", "    edges = edges.edges()\n", "\n", "    # Compute observed and predicted data to plot; frame 6 is taken as the\n", "    # representative frame within each sliding window\n", "    data = coords[random_exp]\n", "    coords_rec = coords.filter_videos([random_exp])\n", "    data_prep = coords_rec.preprocess(\n", "        test_videos=0, window_step=1, window_size=window_size, shuffle=False\n", "    )[0]\n", "\n", "    data_rec = gmvaep.predict(data_prep)\n", "    data_rec = pd.DataFrame(coords_rec._scaler.inverse_transform(data_rec[:, 6, :]))\n", "    data_rec.columns = data.columns\n", "    data = pd.DataFrame(coords_rec._scaler.inverse_transform(data_prep[:, 6, :]))\n", "    data.columns = data_rec.columns\n", "\n", "    # Add the Center coordinate, lost during alignment\n", "    data[\"Center\", \"x\"] = 0\n", "    data[\"Center\", \"y\"] = 0\n", "    data_rec[\"Center\", \"x\"] = 0\n", "    data_rec[\"Center\", \"y\"] = 0\n", "\n", "    # Plot!\n", "    init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_rec_x = data_rec.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_rec_y = data_rec.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "\n", "    plots, rec_plots = plot_mouse_graph(\n", "        init_x, init_y, init_rec_x, init_rec_y, ax, edges\n", "    )\n", "    scatter = ax.scatter(\n", "        x=np.array(init_x), y=np.array(init_y), color=\"#006699\", label=\"Original\"\n", "    )\n", "    rec_scatter = ax.scatter(\n", "        x=np.array(init_rec_x),\n", "        y=np.array(init_rec_y),\n", "        color=\"red\",\n", "        label=\"Reconstruction\",\n", "    )\n", "\n", "    # Update data in main plot\n", "    def animation_frame(i):\n", "        # Update scatter plot\n", "        x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        rec_x = data_rec.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        rec_y = data_rec.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "\n", "        scatter.set_offsets(np.c_[np.array(x), np.array(y)])\n", "        rec_scatter.set_offsets(np.c_[np.array(rec_x), np.array(rec_y)])\n", "        update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges)\n", "\n", "        return scatter\n", "\n", "    animation = FuncAnimation(fig, func=animation_frame, frames=250, interval=50)\n", "\n", "    ax.set_title(\"Original versus reconstructed data\")\n", "    ax.set_ylim(-100, 60)\n", "    ax.set_xlim(-60, 60)\n", "    ax.set_xlabel(\"x\")\n", "    ax.set_ylabel(\"y\")\n", "    plt.legend()\n", "\n", "    video = animation.to_html5_video()\n", "    html = display.HTML(video)\n", "    display.display(html)\n", "    plt.close()\n", "\n", "\n", "animate_mice_across_time(random_exp)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Evaluate latent space (to be incorporated into deepof.evaluate)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get encodings and groupings for the same random video as above\n", "data_prep = coords.preprocess(\n", "    test_videos=0, window_step=1, window_size=window_size, shuffle=False\n", ")[0]\n", "\n", "encodings = encode_to_vector.predict(data_prep)\n", "groupings = grouper.predict(data_prep)\n", "hard_groups = np.argmax(groupings, axis=1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@interact(minimum_confidence=(0.0, 1.0, 0.01))\n", "def plot_cluster_population(minimum_confidence):\n", "    plt.figure(figsize=(12, 8))\n", "\n", "    groups = hard_groups[np.max(groupings, axis=1) > minimum_confidence].flatten()\n", "    # Pad with one instance per cluster, so that all k clusters appear in the plot\n", "    groups = np.concatenate([groups, np.arange(k)])\n", "    sns.countplot(groups)\n", "    plt.xlabel(\"Cluster\")\n", "    plt.title(\"Training instances per cluster\")\n", "    plt.ylim(0, hard_groups.shape[0] * 1.1)\n", "    plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The slider in the figure above sets the minimum confidence with which the model must assign a training instance to a cluster for that instance to be included in the counts." ] },
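{ "cell_type": "markdown", "metadata": {}, "source": [ "To see how confident the grouper is overall, we can also look at the distribution of maximum assignment probabilities across all instances. A minimal sketch:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch: distribution of the grouper's maximum assignment confidence\n", "max_confidences = np.max(groupings, axis=1)\n", "plt.figure(figsize=(8, 4))\n", "plt.hist(max_confidences, bins=50)\n", "plt.xlabel(\"Maximum cluster assignment confidence\")\n", "plt.ylabel(\"Instances\")\n", "plt.title(\"Grouper confidence distribution\")\n", "plt.show()" ] },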
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Plot real data in the latent space\n", "\n", "samples = np.random.choice(range(encodings.shape[0]), 10000)\n", "sample_enc = encodings[samples, :]\n", "sample_grp = groupings[samples, :]\n", "sample_hgr = hard_groups[samples]\n", "k = sample_grp.shape[1]\n", "\n", "umap_reducer = umap.UMAP(n_components=2)\n", "pca_reducer = PCA(n_components=2)\n", "tsne_reducer = TSNE(n_components=2)\n", "lda_reducer = LinearDiscriminantAnalysis(n_components=2)\n", "\n", "umap_enc = umap_reducer.fit_transform(sample_enc)\n", "pca_enc = pca_reducer.fit_transform(sample_enc)\n", "tsne_enc = tsne_reducer.fit_transform(sample_enc)\n", "lda_enc = None  # stays None if LDA cannot be fitted\n", "try:\n", "    lda_enc = lda_reducer.fit_transform(sample_enc, sample_hgr)\n", "except ValueError:\n", "    warnings.warn(\"Only one class found. Can't use LDA\", UserWarning, stacklevel=2)\n", "\n", "\n", "@interact(\n", "    minimum_confidence=(0.0, 0.99, 0.01),\n", "    dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"],\n", "    highlight_clusters=False,\n", "    selected_cluster=(0, k - 1),\n", ")\n", "def plot_static_latent_space(\n", "    minimum_confidence, dim_red, highlight_clusters, selected_cluster\n", "):\n", "\n", "    global sample_enc, sample_grp, sample_hgr\n", "\n", "    if dim_red == \"umap\":\n", "        enc = umap_enc\n", "    elif dim_red == \"LDA\":\n", "        enc = lda_enc\n", "    elif dim_red == \"PCA\":\n", "        enc = pca_enc\n", "    else:\n", "        enc = tsne_enc\n", "\n", "    if enc is None:\n", "        print(\"LDA could not be fitted; please select another projection\")\n", "        return\n", "\n", "    enc = enc[np.max(sample_grp, axis=1) > minimum_confidence]\n", "    hgr = sample_hgr[np.max(sample_grp, axis=1) > minimum_confidence].flatten()\n", "    grp = sample_grp[np.max(sample_grp, axis=1) > minimum_confidence]\n", "\n", "    plt.figure(figsize=(12, 8))\n", "\n", "    sns.scatterplot(\n", "        x=enc[:, 0],\n", "        y=enc[:, 1],\n", "        hue=hgr,\n", "        size=np.max(grp, axis=1),\n", "        sizes=(1, 100),\n", "        palette=sns.color_palette(\"husl\", len(set(hgr))),\n", "    )\n", "\n", "    if highlight_clusters:\n", "        sns.kdeplot(\n", "            enc[hgr == selected_cluster, 0],\n", "            enc[hgr == selected_cluster, 1],\n", "            color=\"red\",\n", "        )\n", "\n", "    plt.xlabel(\"{} 1\".format(dim_red))\n", "    plt.ylabel(\"{} 2\".format(dim_red))\n", "    plt.suptitle(\"Static view of the trained latent space\")\n", "    plt.show()" ] },
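{ "cell_type": "markdown", "metadata": {}, "source": [ "How faithful is the 2D PCA view above? The fitted reducer exposes the variance captured by each component, which gives a rough sense of how much latent structure survives the projection. A minimal sketch:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch: variance explained by the two PCA components used above\n", "print(\"PCA explained variance ratio:\", pca_reducer.explained_variance_ratio_)\n", "print(\"Total variance captured: {:.1%}\".format(np.sum(pca_reducer.explained_variance_ratio_)))" ] },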
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Redefine the graph helpers for a single (non-reconstructed) mouse\n", "def plot_mouse_graph(instant_x, instant_y, ax, edges):\n", "    \"\"\"Generates a graph plot of the mouse\"\"\"\n", "    plots = []\n", "    for edge in edges:\n", "        (temp_plot,) = ax.plot(\n", "            [float(instant_x[edge[0]]), float(instant_x[edge[1]])],\n", "            [float(instant_y[edge[0]]), float(instant_y[edge[1]])],\n", "            color=\"#006699\",\n", "            linewidth=2.0,\n", "        )\n", "        plots.append(temp_plot)\n", "    return plots\n", "\n", "\n", "def update_mouse_graph(x, y, plots, edges):\n", "    \"\"\"Updates the graph plot to enable animation\"\"\"\n", "\n", "    for plot, edge in zip(plots, edges):\n", "        plot.set_data(\n", "            [float(x[edge[0]]), float(x[edge[1]])],\n", "            [float(y[edge[0]]), float(y[edge[1]])],\n", "        )" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Plot the trajectory of a video in the latent space\n", "\n", "samples = np.random.choice(range(encodings.shape[0]), 10000)\n", "sample_enc = encodings[samples, :]\n", "sample_grp = groupings[samples, :]\n", "sample_hgr = hard_groups[samples]\n", "k = sample_grp.shape[1]\n", "\n", "umap_reducer = umap.UMAP(n_components=2)\n", "pca_reducer = PCA(n_components=2)\n", "tsne_reducer = TSNE(n_components=2)\n", "lda_reducer = LinearDiscriminantAnalysis(n_components=2)\n", "\n", "umap_enc = umap_reducer.fit_transform(sample_enc)\n", "pca_enc = pca_reducer.fit_transform(sample_enc)\n", "tsne_enc = tsne_reducer.fit_transform(sample_enc)\n", "lda_enc = None  # stays None if LDA cannot be fitted\n", "try:\n", "    lda_enc = lda_reducer.fit_transform(sample_enc, sample_hgr)\n", "except ValueError:\n", "    warnings.warn(\"Only one class found. Can't use LDA\", UserWarning, stacklevel=2)\n", "\n", "\n", "@interact(\n", "    trajectory=(100, 500), trace=False, dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"],\n", ")\n", "def plot_dynamic_latent_space(trajectory, trace, dim_red):\n", "\n", "    global sample_enc, sample_grp, sample_hgr\n", "\n", "    if dim_red == \"umap\":\n", "        enc = umap_enc\n", "    elif dim_red == \"LDA\":\n", "        enc = lda_enc\n", "    elif dim_red == \"PCA\":\n", "        enc = pca_enc\n", "    else:\n", "        enc = tsne_enc\n", "\n", "    if enc is None:\n", "        print(\"LDA could not be fitted; please select another projection\")\n", "        return\n", "\n", "    traj_enc = enc[:trajectory, :]\n", "    traj_grp = sample_grp[:trajectory, :]\n", "    traj_hgr = sample_hgr[:trajectory]\n", "\n", "    # Define two figures arranged horizontally\n", "    fig, (ax, ax2) = plt.subplots(\n", "        1, 2, figsize=(12, 8), gridspec_kw={\"width_ratios\": [3, 1.5]}\n", "    )\n", "\n", "    # Plot the animated embedding trajectory on the left\n", "    sns.scatterplot(\n", "        x=enc[:, 0],\n", "        y=enc[:, 1],\n", "        hue=sample_hgr,\n", "        size=np.max(sample_grp, axis=1),\n", "        sizes=(1, 100),\n", "        palette=sns.color_palette(\"husl\", len(set(sample_hgr))),\n", "        ax=ax,\n", "    )\n", "\n", "    traj_init = traj_enc[0, :]\n", "    scatter = ax.scatter(\n", "        x=[traj_init[0]], y=[traj_init[1]], s=100, color=\"red\", edgecolor=\"black\"\n", "    )\n", "    (lineplt,) = ax.plot([traj_init[0]], [traj_init[1]], color=\"red\", linewidth=2.0)\n", "    tracking_line_x = []\n", "    tracking_line_y = []\n", "\n", "    # Plot the initial data (before feeding it to the encoder) on the right\n", "    edges = deepof.utils.connect_mouse_topview()\n", "\n", "    for bpart in exclude_bodyparts:\n", "        if bpart:\n", "            edges.remove_node(bpart)\n", "\n", "    for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n", "        edges.remove_edge(\"Center\", limb)\n", "        if (\"Tail_base\", limb) in list(edges.edges()):\n", "            edges.remove_edge(\"Tail_base\", limb)\n", "\n", "    edges = edges.edges()\n", "\n", "    inv_coords = coords._scaler.inverse_transform(data_prep)[:, window_size // 2, :]\n", "    data = pd.DataFrame(inv_coords, columns=coords[random_exp].columns)\n", "\n", "    data[\"Center\", \"x\"] = 0\n", "    data[\"Center\", \"y\"] = 0\n", "\n", "    init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "\n", "    plots = plot_mouse_graph(init_x, init_y, ax2, edges)\n", "    track = ax2.scatter(x=np.array(init_x), y=np.array(init_y), color=\"#006699\")\n", "\n", "    # Update data in both plots\n", "    def animation_frame(i):\n", "        # Update scatter plot\n", "        offset = traj_enc[i, :]\n", "\n", "        prev_t = scatter.get_offsets()[0]\n", "\n", "        if trace:\n", "            tracking_line_x.append([prev_t[0], offset[0]])\n", "            tracking_line_y.append([prev_t[1], offset[1]])\n", "            lineplt.set_xdata(tracking_line_x)\n", "            lineplt.set_ydata(tracking_line_y)\n", "\n", "        scatter.set_offsets(np.c_[np.array(offset[0]), np.array(offset[1])])\n", "\n", "        x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        track.set_offsets(np.c_[np.array(x), np.array(y)])\n", "        update_mouse_graph(x, y, plots, edges)\n", "\n", "        return scatter\n", "\n", "    animation = FuncAnimation(\n", "        fig, func=animation_frame, frames=trajectory, interval=75,\n", "    )\n", "\n", "    ax.set_xlabel(\"{} 1\".format(dim_red))\n", "    ax.set_ylabel(\"{} 2\".format(dim_red))\n", "\n", "    ax2.set_xlabel(\"x\")\n", "    ax2.set_ylabel(\"y\")\n", "    ax2.set_ylim(-90, 60)\n", "    ax2.set_xlim(-60, 60)\n", "\n", "    plt.tight_layout()\n", "\n", "    video = animation.to_html5_video()\n", "    html = display.HTML(video)\n", "    display.display(html)\n", "    plt.close()" ] },
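{ "cell_type": "markdown", "metadata": {}, "source": [ "A complementary, non-animated way of looking at temporal structure is to count how often the hard cluster assignment changes from one frame to the next: frequent switching suggests noisy assignments, while long runs suggest temporally consistent clusters. A minimal sketch:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch: frame-to-frame stability of the hard cluster assignments\n", "switches = np.sum(hard_groups[1:] != hard_groups[:-1])\n", "switch_rate = switches / (len(hard_groups) - 1)\n", "print(\n", "    \"Cluster switches: {} out of {} transitions ({:.1%})\".format(\n", "        switches, len(hard_groups) - 1, switch_rate\n", "    )\n", ")" ] },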
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Sample from latent space (to be incorporated into deepof.evaluate)" ] },
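{ "cell_type": "markdown", "metadata": {}, "source": [ "Before generating clips, it can help to inspect the mixture prior we are about to sample from. A minimal sketch, using the same `components_distribution` accessors as the cell below:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch: basic statistics of the GMM prior\n", "prior_means = prior.components_distribution.mean().numpy()\n", "prior_stds = prior.components_distribution.stddev().numpy()\n", "print(\"Mixture components:\", prior_means.shape[0])\n", "print(\"Latent dimensions:\", prior_means.shape[1])\n", "print(\"Average per-component standard deviation: {:.3f}\".format(prior_stds.mean()))" ] },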
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sample from the prior distribution and visualize the generated data\n", "\n", "means = prior.components_distribution.mean().numpy()\n", "stddevs = prior.components_distribution.stddev().numpy()\n", "\n", "# Draw 500 samples per mixture component and decode them\n", "samples = []\n", "for i in range(means.shape[0]):\n", "    samples.append(\n", "        np.random.normal(means[i, :], stddevs[i, :], size=(500, means.shape[1]))\n", "    )\n", "samples = np.concatenate(samples)\n", "decodings = decoder.predict(samples)\n", "\n", "umap_reducer = umap.UMAP(n_components=2)\n", "pca_reducer = PCA(n_components=2)\n", "tsne_reducer = TSNE(n_components=2)\n", "lda_reducer = LinearDiscriminantAnalysis(n_components=2)\n", "\n", "umap_enc = umap_reducer.fit_transform(samples)\n", "pca_enc = pca_reducer.fit_transform(samples)\n", "tsne_enc = tsne_reducer.fit_transform(samples)\n", "lda_enc = lda_reducer.fit_transform(samples, np.repeat(range(means.shape[0]), 500))\n", "\n", "\n", "@interact(dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"], selected_cluster=(0, k - 1))\n", "def sample_from_prior(dim_red, selected_cluster):\n", "\n", "    if dim_red == \"umap\":\n", "        sample_enc = umap_enc\n", "    elif dim_red == \"LDA\":\n", "        sample_enc = lda_enc\n", "    elif dim_red == \"PCA\":\n", "        sample_enc = pca_enc\n", "    else:\n", "        sample_enc = tsne_enc\n", "\n", "    fig, (ax, ax2) = plt.subplots(\n", "        1, 2, figsize=(12, 8), gridspec_kw={\"width_ratios\": [3, 1.5]}\n", "    )\n", "\n", "    hue = np.repeat(range(means.shape[0]), 500)\n", "\n", "    # Plot the sampled latent points on the left\n", "    sns.scatterplot(\n", "        x=sample_enc[:, 0],\n", "        y=sample_enc[:, 1],\n", "        hue=hue,\n", "        palette=sns.color_palette(\"husl\", k),\n", "        ax=ax,\n", "    )\n", "\n", "    sns.kdeplot(\n", "        sample_enc[hue == selected_cluster, 0],\n", "        sample_enc[hue == selected_cluster, 1],\n", "        color=\"red\",\n", "        ax=ax,\n", "    )\n", "\n", "    # Get reconstructions from samples of a given cluster\n", "    decs = decodings[hue == selected_cluster][np.random.randint(0, 500, 5)]\n", "\n", "    # Plot the decoded, generated data as an animation on the right\n", "    edges = deepof.utils.connect_mouse_topview()\n", "\n", "    for bpart in exclude_bodyparts:\n", "        if bpart:\n", "            edges.remove_node(bpart)\n", "\n", "    for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n", "        edges.remove_edge(\"Center\", limb)\n", "        if (\"Tail_base\", limb) in list(edges.edges()):\n", "            edges.remove_edge(\"Tail_base\", limb)\n", "\n", "    edges = edges.edges()\n", "\n", "    inv_coords = coords._scaler.inverse_transform(decs).reshape(\n", "        decs.shape[0] * decs.shape[1], decs.shape[2]\n", "    )\n", "    data = pd.DataFrame(inv_coords, columns=coords[random_exp].columns)\n", "\n", "    data[\"Center\", \"x\"] = 0\n", "    data[\"Center\", \"y\"] = 0\n", "\n", "    init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "    init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "\n", "    plots = plot_mouse_graph(init_x, init_y, ax2, edges)\n", "    track = ax2.scatter(x=np.array(init_x), y=np.array(init_y), color=\"#006699\")\n", "\n", "    # Update data in both plots\n", "    def animation_frame(i):\n", "        x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "        track.set_offsets(np.c_[np.array(x), np.array(y)])\n", "        update_mouse_graph(x, y, plots, edges)\n", "\n", "    animation = FuncAnimation(\n", "        fig, func=animation_frame, frames=5 * window_size, interval=50,\n", "    )\n", "\n", "    ax.set_xlabel(\"{} 1\".format(dim_red))\n", "    ax.set_ylabel(\"{} 2\".format(dim_red))\n", "    ax.get_legend().remove()\n", "\n", "    ax2.set_xlabel(\"x\")\n", "    ax2.set_ylabel(\"y\")\n", "    ax2.set_ylim(-90, 60)\n", "    ax2.set_xlim(-60, 60)\n", "\n", "    plt.tight_layout()\n", "\n", "    video = animation.to_html5_video()\n", "    html = display.HTML(video)\n", "    display.display(html)\n", "    plt.close()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }