From fca79d553644b4f61c4789ca78c9469181b329d9 Mon Sep 17 00:00:00 2001 From: lucas_miranda <lucasmiranda42@gmail.com> Date: Tue, 6 Apr 2021 17:34:20 +0200 Subject: [PATCH] Added _scaler method to table_dict, to retrieve the scaler used to preprocess data --- deepof/data.py | 9 +- .../deepof_data_exploration.ipynb | 52 +++-- .../deepof_model_evaluation.ipynb | 205 ++++++++++++++++-- 3 files changed, 222 insertions(+), 44 deletions(-) diff --git a/deepof/data.py b/deepof/data.py index cbea413d..2cbdd423 100644 --- a/deepof/data.py +++ b/deepof/data.py @@ -997,6 +997,7 @@ class table_dict(dict): self._arena_dims = arena_dims self._propagate_labels = propagate_labels self._propagate_annotations = propagate_annotations + self._scaler = None def filter_videos(self, keys: list) -> Table_dict: """Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example, @@ -1156,15 +1157,15 @@ class table_dict(dict): print("Scaling data...") if scale == "standard": - scaler = StandardScaler() + self._scaler = StandardScaler() elif scale == "minmax": - scaler = MinMaxScaler() + self._scaler = MinMaxScaler() else: raise ValueError( "Invalid scaler. Select one of standard, minmax or None" ) # pragma: no cover - X_train = scaler.fit_transform( + X_train = self._scaler.fit_transform( X_train.reshape(-1, X_train.shape[-1]) ).reshape(X_train.shape) @@ -1173,7 +1174,7 @@ class table_dict(dict): assert np.allclose(np.nan_to_num(np.std(X_train), nan=1), 1) if test_videos: - X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape( + X_test = self._scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape( X_test.shape ) diff --git a/supplementary_notebooks/deepof_data_exploration.ipynb b/supplementary_notebooks/deepof_data_exploration.ipynb index 3fd181a0..dd39d667 100644 --- a/supplementary_notebooks/deepof_data_exploration.ipynb +++ b/supplementary_notebooks/deepof_data_exploration.ipynb @@ -10,6 +10,16 @@ "os.chdir(os.path.dirname(\"../\"))" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -31,17 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -67,7 +67,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "exclude_bodyparts = tuple([\"\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "metadata": { "scrolled": true }, @@ -92,6 +101,7 @@ " path=\"../../Desktop/deepoftesttemp/\",\n", " arena_dims=[380],\n", " arena_detection=\"rule-based\",\n", + " exclude_bodyparts=exclude_bodyparts,\n", " interpolate_outliers=True,\n", ").run()" ] @@ -311,6 +321,8 @@ "metadata": {}, "outputs": [], "source": [ + "# Auxiliary animation functions\n", + "\n", "def plot_mouse_graph(instant_x, instant_y, ax, edges):\n", " \"\"\"Generates a graph plot of the mouse\"\"\"\n", " plots = []\n", @@ -336,13 +348,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f5a8a2dbd0074f33af0f87dfcc279a28", + "model_id": "5ca717ea26e748c3aee0c6b0112a044d", "version_major": 2, "version_minor": 0 }, @@ -364,6 +376,10 @@ " fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", "\n", " edges = deepof.utils.connect_mouse_topview()\n", + " \n", + " for bpart in exclude_bodyparts:\n", + " edges.remove_node(bpart)\n", + " \n", " for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n", " edges.remove_edge(\"Center\", limb)\n", " edges = edges.edges()\n", @@ -375,8 +391,8 @@ " data[\"Center\", \"x\"] = 0\n", " data[\"Center\", \"y\"] = 0\n", "\n", - " init_x = data.xs(\"x\", level=\"coords\", axis=1, drop_level=False).iloc[0, :]\n", - " init_y = data.xs(\"y\", level=\"coords\", axis=1, drop_level=False).iloc[0, :]\n", + " init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", + " init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", "\n", " plots = plot_mouse_graph(init_x, init_y, ax, edges)\n", " scatter = ax.scatter(x=np.array(init_x), y=np.array(init_y), color=\"#006699\",)\n", @@ -384,8 +400,8 @@ " # Update data in main plot\n", " def animation_frame(i):\n", " # Update scatter plot\n", - " x = data.xs(\"x\", level=\"coords\", axis=1, drop_level=False).iloc[i, :]\n", - " y = data.xs(\"y\", level=\"coords\", axis=1, drop_level=False).iloc[i, :]\n", + " x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", + " y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", "\n", " scatter.set_offsets(np.c_[np.array(x), np.array(y)])\n", " update_mouse_graph(x, y, plots, edges)\n", diff --git a/supplementary_notebooks/deepof_model_evaluation.ipynb b/supplementary_notebooks/deepof_model_evaluation.ipynb index 0b45cab1..3407356d 100644 --- a/supplementary_notebooks/deepof_model_evaluation.ipynb +++ b/supplementary_notebooks/deepof_model_evaluation.ipynb @@ -10,6 +10,16 @@ "%autoreload 2" ] }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -71,25 +81,27 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepoftesttemp\")\n", - "trained_network = os.path.join(\"..\", \"..\", \"Desktop\")" + "trained_network = os.path.join(\"..\", \"..\", \"Desktop\")\n", + "exclude_bodyparts = [\"Tail_1\", \"Tail_2\", \"Tail_tip\", \"Tail_base\"]\n", + "window_size = 24" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 297 ms, sys: 23.9 ms, total: 321 ms\n", - "Wall time: 270 ms\n" + "CPU times: user 280 ms, sys: 22.3 ms, total: 303 ms\n", + "Wall time: 250 ms\n" ] } ], @@ -98,14 +110,14 @@ "proj = deepof.data.project(\n", " path=path,\n", " smooth_alpha=0.99,\n", - " exclude_bodyparts=[\"Tail_1\", \"Tail_2\", \"Tail_tip\", \"Tail_base\"],\n", + " exclude_bodyparts=exclude_bodyparts,\n", " arena_dims=[380],\n", ")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 55, "metadata": {}, "outputs": [ { @@ -120,8 +132,8 @@ "Computing angles...\n", "Done!\n", "deepof analysis of 2 videos\n", - "CPU times: user 2.63 s, sys: 86.3 ms, total: 2.72 s\n", - "Wall time: 683 ms\n" + "CPU times: user 2.63 s, sys: 85.8 ms, total: 2.71 s\n", + "Wall time: 634 ms\n" ] } ], @@ -140,17 +152,17 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n", - "preprocessed_data,_,_,_ = coords.preprocess(test_videos=0, window_step=5, window_size=24)" + "preprocessed_data,_,_,_ = coords.preprocess(test_videos=0, window_step=5, window_size=window_size)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -164,9 +176,23 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 58, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ValueError", + "evalue": "You are trying to load a weight file containing 15 layers into a model with 14 layers.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-58-ea431bf97d05>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m gmvaep.load_weights(\n\u001b[1;32m 11\u001b[0m os.path.join(\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mtrained_network\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrained_network\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"h5\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m )\n\u001b[1;32m 14\u001b[0m )\n", + "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py\u001b[0m in \u001b[0;36mload_weights\u001b[0;34m(self, filepath, by_name, skip_mismatch, options)\u001b[0m\n\u001b[1;32m 2232\u001b[0m f, self.layers, skip_mismatch=skip_mismatch)\n\u001b[1;32m 2233\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2234\u001b[0;31m \u001b[0mhdf5_format\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_weights_from_hdf5_group\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2235\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2236\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_updated_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/saving/hdf5_format.py\u001b[0m in \u001b[0;36mload_weights_from_hdf5_group\u001b[0;34m(f, layers)\u001b[0m\n\u001b[1;32m 686\u001b[0m \u001b[0;34m'containing '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer_names\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[0;34m' layers into a model with '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_layers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 688\u001b[0;31m ' layers.')\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 690\u001b[0m \u001b[0;31m# We batch weight value assignments in a single backend call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: You are trying to load a weight file containing 15 layers into a model with 14 layers." + ] + } + ], "source": [ "encoder, decoder, grouper, gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(\n", " loss=loss,\n", @@ -186,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 59, "metadata": { "scrolled": true }, @@ -201,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -237,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 61, "metadata": {}, "outputs": [], "source": [ @@ -255,7 +281,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -265,12 +291,12 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "# Rescale reconstructions\n", - "rescaled_reconstructions = scaler.transform(\n", + "rescaled_reconstructions = scaler.inverse_transform(\n", " reconstrs.reshape(reconstrs.shape[0] * reconstrs.shape[1], reconstrs.shape[2])\n", ")\n", "rescaled_reconstructions = rescaled_reconstructions.reshape(reconstrs.shape)" @@ -278,11 +304,139 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 98, "metadata": {}, "outputs": [], "source": [ - "# Display a video with the original data superimposed with the reconstructions\n" + "# Auxiliary animation functions\n", + "\n", + "\n", + "def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):\n", + " \"\"\"Generates a graph plot of the mouse\"\"\"\n", + " plots = []\n", + " rec_plots = []\n", + " for edge in edges:\n", + " (temp_plot,) = ax.plot(\n", + " [float(instant_x[edge[0]]), float(instant_x[edge[1]])],\n", + " [float(instant_y[edge[0]]), float(instant_y[edge[1]])],\n", + " color=\"#006699\",\n", + " )\n", + " (temp_rec_plot,) = ax.plot(\n", + " [float(instant_rec_x[edge[0]]), float(instant_rec_x[edge[1]])],\n", + " [float(instant_rec_y[edge[0]]), float(instant_rec_y[edge[1]])],\n", + " color=\"#006699\",\n", + " )\n", + " plots.append(temp_plot)\n", + " rec_plots.append(temp_rec_plot)\n", + " return plots\n", + "\n", + "\n", + "def update_mouse_graph(x, y, plots, edges):\n", + " \"\"\"Updates the graph plot to enable animation\"\"\"\n", + "\n", + " for plot, edge in zip(plots, edges):\n", + " plot.set_data(\n", + " [float(x[edge[0]]), float(x[edge[1]])],\n", + " [float(y[edge[0]]), float(y[edge[1]])],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "db7a005dac9b4673bc2c91dac50c7db0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatSlider(value=7500.0, description='time_slider', max=15000.0, step=500.0), IntSlider…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Display a video with the original data superimposed with the reconstructions\n", + "\n", + "random_exp = np.random.choice(list(coords.keys()), 1)[0]\n", + "\n", + "\n", + "@interact(time_slider=(0.0, 15000, 500), length_slider=(0, 1000, 100))\n", + "def animate_mice_across_time(time_slider, length_slider):\n", + "\n", + " fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n", + "\n", + " edges = deepof.utils.connect_mouse_topview()\n", + "\n", + " for bpart in exclude_bodyparts:\n", + " edges.remove_node(bpart)\n", + "\n", + " for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n", + " edges.remove_edge(\"Center\", limb)\n", + "\n", + " edges = edges.edges()\n", + "\n", + " data = coords[random_exp].loc[time_slider : time_slider + length_slider - 1, :]\n", + " data_rec = gmvaep.predict(\n", + " coords.filter_videos([random_exp]).preprocess(\n", + " test_videos=0, window_step=5, window_size=window_size\n", + " )[0]\n", + " )\n", + " data_rec = pd.DataFrame(scaler.inverse_transform(data_rec[:, 24 // 2, :]))\n", + " data_rec.columns = data.columns\n", + "\n", + " data[\"Center\", \"x\"] = 0\n", + " data[\"Center\", \"y\"] = 0\n", + " data_rec[\"Center\", \"x\"] = 0\n", + " data_rec[\"Center\", \"y\"] = 0\n", + "\n", + " init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", + " init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", + " init_rec_x = data_rec.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n", + " init_rec_y = data_rec.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n", + "\n", + " plots = plot_mouse_graph(init_x, init_y, init_rec_x, init_rec_y, ax, edges)\n", + " scatter = ax.scatter(x=np.array(init_x), y=np.array(init_y), color=\"#006699\",)\n", + " rec_scatter = ax.scatter(\n", + " x=np.array(init_rec_x), y=np.array(init_rec_y), color=\"#006699\",\n", + " )\n", + "\n", + " # Update data in main plot\n", + " def animation_frame(i):\n", + " # Update scatter plot\n", + " x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", + " y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", + " rec_x = data_rec.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n", + " rec_y = data_rec.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n", + " \n", + " scatter.set_offsets(np.c_[np.array(x), np.array(y)])\n", + " scatter.set_offsets(np.c_[np.array(rec_x), np.array(rec_y)])\n", + " update_mouse_graph(x, y, plots, edges)\n", + "\n", + " return scatter\n", + "\n", + " animation = FuncAnimation(\n", + " fig, func=animation_frame, frames=length_slider, interval=100,\n", + " )\n", + "\n", + " ax.set_title(\"Positions across time for centered data\")\n", + " ax.set_ylim(-100, 60)\n", + " ax.set_xlim(-60, 60)\n", + " ax.set_xlabel(\"x\")\n", + " ax.set_ylabel(\"y\")\n", + "\n", + " video = animation.to_html5_video()\n", + " html = display.HTML(video)\n", + " display.display(html)\n", + " plt.close()" ] }, { @@ -299,6 +453,13 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, -- GitLab