Commit aadc4d27 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added notebook for dataset and model evaluation

parent cbbf9380
Pipeline #97696 passed with stages
in 17 minutes and 6 seconds
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
import os
# Hop one directory up so the deepof package root is importable
# (os.path.dirname("../") evaluates to "..")
os.chdir(os.path.dirname("../"))
```
%% Cell type:markdown id: tags:
### Latent space sampling
# deepOF data exploration
%% Cell type:markdown id: tags:
Given a dataset and a trained model, this notebook allows the user to sample from the latent space distributions and generate video clips showcasing the results
Given a dataset, this notebook allows the user to
* Load and process the dataset using deepof.data
* Visualize data quality with interactive plots
* Visualize training instances as multi-timepoint scatter plots with interactive configurations
* Visualize training instances as video clips with interactive configurations
%% Cell type:code id: tags:
``` python
import os
# Move the working directory one level up so deepof can be imported
# (os.path.dirname("../") evaluates to "..")
os.chdir(os.path.dirname("../"))
import warnings
# Silence library warnings to keep the notebook output readable
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
import cv2
import deepof.data
import deepof.models
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import deepof.utils
import numpy as np
import pandas as pd
import re
import seaborn as sns
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import tensorflow as tf
import tqdm.notebook as tqdm
from datetime import datetime
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import umap
from ipywidgets import interact
from IPython import display
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import seaborn as sns
```
%% Cell type:markdown id: tags:
### 1. Define and run project
%% Cell type:code id: tags:
``` python
# Locations of the input dataset and of the pretrained network weights.
# os.pardir is the platform's parent-directory marker ("..").
path = os.path.join(os.pardir, os.pardir, "Desktop", "deepoftesttemp")
trained_network = os.path.join(os.pardir, os.pardir, "Desktop")
```
%% Cell type:code id: tags:
``` python
%%time
# Use deepof to load a project
# NOTE(review): this cell was garbled by the scraped diff — it contained a
# duplicated `path=` keyword (a SyntaxError) and orphaned argument lines.
# Reconstructed from the two interleaved revisions; confirm the keyword set
# against deepof.data.project's signature.
proj = deepof.data.project(
    path=path,
    smooth_alpha=0.99,
    exclude_bodyparts=["Tail_1", "Tail_2", "Tail_tip", "Tail_base"],
    arena_dims=[380],
    arena_detection="rule-based",
    interpolate_outliers=True,
).run()
```
%% Cell type:code id: tags:
%% Cell type:markdown id: tags:
``` python
%%time
proj = proj.run(verbose=True)
print(proj)
```
### 2. Inspect dataset quality
%% Cell type:code id: tags:
``` python
# Stack the per-video DeepLabCut quality tables into one dataframe
all_quality = pd.concat(list(proj.get_quality().values()))
......@@ -115,100 +100,206 @@
%%%% Output: display_data
%% Cell type:markdown id: tags:
### 2. Load pretrained deepof model
In the cell above, you see the percentage of labels per body part which have a quality lower than the selected value (0.50 by default) **before** preprocessing. The values are taken directly from DeepLabCut.
%% Cell type:markdown id: tags:
### 3. Get coordinates, distances and angles
%% Cell type:markdown id: tags:
Then compute speed, acceleration and jerk for each of them
%% Cell type:code id: tags:
``` python
# NOTE(review): the first two lines appear to be diff residue from an older
# revision of this cell — coords/preprocessed_data are not used below here.
coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
preprocessed_data,_,_,_ = coords.preprocess(test_videos=0, window_step=5, window_size=24)
# Get coordinates, speeds, accelerations and jerks for positions
# (per the comment above, speed=1/2/3 presumably selects the derivative
# order: speed, acceleration, jerk — confirm against deepof.data docs)
position_coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
position_speeds = proj.get_coords(center="Center", speed=1)
position_accels = proj.get_coords(center="Center", speed=2)
position_jerks = proj.get_coords(center="Center", speed=3)
```
%% Cell type:code id: tags:
``` python
# Set model parameters
# NOTE(review): these five assignments look like diff residue from the model
# evaluation notebook (they also appear verbatim further down this page);
# they are unused by the distance computations below.
encoding=6
loss="ELBO"
k=25
pheno=0
predictor=0
# Get coordinates, speeds, accelerations and jerks for distances
# (speed=1/2/3 presumably selects the derivative order, as above)
distance_coords = proj.get_distances()
distance_speeds = proj.get_distances(speed=1)
distance_accels = proj.get_distances(speed=2)
distance_jerks = proj.get_distances(speed=3)
```
%% Cell type:code id: tags:
``` python
# NOTE(review): the SEQ_2_SEQ_GMVAE build below looks like diff residue from
# the model evaluation notebook further down this page; the angle
# computations that follow are the lines this cell actually introduces.
encoder, decoder, grouper, gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
loss=loss,
number_of_components=k,
compile_model=True,
encoding=encoding,
predictor=predictor,
phenotype_prediction=pheno,
).build(preprocessed_data.shape)[:4]
# Get coordinates, speeds, accelerations and jerks for angles
angle_coords = proj.get_angles()
angle_speeds = proj.get_angles(speed=1)
angle_accels = proj.get_angles(speed=2)
angle_jerks = proj.get_angles(speed=3)
```
# Load pretrained weights into the full GMVAE model, picking the first
# file ending in "h5" found in the trained_network directory.
# NOTE(review): diff residue — the closing parentheses of this call were
# stranded elsewhere in the scraped hunk; restored here so the cell parses.
gmvaep.load_weights(
    os.path.join(
        trained_network, [i for i in os.listdir(trained_network) if i.endswith("h5")][0]
    )
)
%% Cell type:markdown id: tags:
### 4. Display training instances
%% Cell type:code id: tags:
``` python
# Pick one experiment at random to visualize
random_exp = np.random.choice(list(position_coords.keys()), 1)[0]


@interact(time_slider=(0.0, 15000, 25), length_slider=(10, 100, 5))
def plot_mice_across_time(time_slider, length_slider):
    """Scatter-plot every body part of the selected experiment over a
    window of frames, to show what a training instance looks like.

    time_slider selects the window start (frame index); length_slider the
    number of consecutive frames drawn.

    NOTE(review): the scraped diff had left a stray orphan ")" inside this
    function body (residue of a load_weights call); removed here.
    """
    plt.figure(figsize=(10, 10))
    for bpart in position_coords[random_exp].columns.levels[0]:
        # "Center" is the centering reference (get_coords(center="Center")),
        # so its coordinates carry no information worth plotting
        if bpart != "Center":
            sns.scatterplot(
                data=position_coords[random_exp].loc[
                    time_slider : time_slider + length_slider - 1, bpart
                ],
                x="x",
                y="y",
                label=bpart,
                palette=sns.color_palette("tab10"),
            )
    plt.title("Positions across time for centered data")
    plt.legend(
        fontsize=15,
        bbox_to_anchor=(1.5, 1),
        title="Body part",
        title_fontsize=18,
        shadow=False,
        facecolor="white",
    )
    plt.ylim(-100, 60)
    plt.xlim(-60, 60)
    plt.show()
```
%%%% Output: display_data
%% Cell type:markdown id: tags:
The figure above is a multi time-point scatter plot. The time_slider allows you to scroll across the video, and the length_slider selects the number of time-points to include. The idea is to intuitively visualize the data that goes into a training instance for a given preprocessing setting.
%% Cell type:code id: tags:
``` python
# Uncomment to see model summaries
# encoder.summary()
# decoder.summary()
# grouper.summary()
# gmvaep.summary()
def plot_mouse_graph(instant_x, instant_y, ax, edges):
    """Draw the mouse skeleton on *ax* as one line segment per edge.

    instant_x / instant_y are indexable by body-part name; edges is an
    iterable of (part_a, part_b) pairs. Returns the list of line artists
    so they can be updated later for animation.
    """
    segments = []
    for edge in edges:
        xs = [float(instant_x[edge[0]]), float(instant_x[edge[1]])]
        ys = [float(instant_y[edge[0]]), float(instant_y[edge[1]])]
        (segment,) = ax.plot(xs, ys, color="#006699")
        segments.append(segment)
    return segments
def update_mouse_graph(x, y, plots, edges):
    """Refresh the skeleton drawn by plot_mouse_graph with new coordinates.

    Walks the line artists in *plots* in lockstep with *edges* and resets
    each segment's endpoints, enabling frame-by-frame animation.
    """
    for line, edge in zip(plots, edges):
        new_x = [float(x[edge[0]]), float(x[edge[1]])]
        new_y = [float(y[edge[0]]), float(y[edge[1]])]
        line.set_data(new_x, new_y)
```
%% Cell type:code id: tags:
``` python
# Uncomment to plot model structure
# Uncomment to plot model structure
def plot_model(model, name):
    """Write a PNG diagram of *model* into the project path.

    The file name embeds *name* and a timestamp so repeated calls never
    overwrite each other.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    out_file = os.path.join(path, "deepof_{}_{}.png".format(name, timestamp))
    tf.keras.utils.plot_model(
        model,
        to_file=out_file,
        show_shapes=True,
        show_dtype=False,
        show_layer_names=True,
        rankdir="TB",
        expand_nested=True,
        dpi=200,
    )
# Pick one experiment at random to animate
random_exp = np.random.choice(list(position_coords.keys()), 1)[0]


@interact(time_slider=(0.0, 15000, 25), length_slider=(10, 100, 5))
def animate_mice_across_time(time_slider, length_slider):
    """Render one training window as an HTML5 video of the moving mouse
    skeleton (same data as the scatter plot above, one frame at a time).

    NOTE(review): four commented-out plot_model calls — diff residue from
    another cell — were stranded inside animation_frame; removed here.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # Build the skeleton graph, dropping the Center-to-hip edges
    edges = deepof.utils.connect_mouse_topview()
    for limb in ["Left_fhip", "Right_fhip", "Left_bhip", "Right_bhip"]:
        edges.remove_edge("Center", limb)
    edges = edges.edges()

    data = position_coords[random_exp].loc[
        time_slider : time_slider + length_slider - 1, :
    ]
    # Center is the alignment origin, so pin it at (0, 0)
    data["Center", "x"] = 0
    data["Center", "y"] = 0

    # Initial frame: skeleton segments plus a scatter of body parts
    init_x = data.xs("x", level="coords", axis=1, drop_level=False).iloc[0, :]
    init_y = data.xs("y", level="coords", axis=1, drop_level=False).iloc[0, :]
    plots = plot_mouse_graph(init_x, init_y, ax, edges)
    scatter = ax.scatter(x=np.array(init_x), y=np.array(init_y), color="#006699")

    # Update data in main plot
    def animation_frame(i):
        # Update scatter plot and skeleton segments for frame i
        x = data.xs("x", level="coords", axis=1, drop_level=False).iloc[i, :]
        y = data.xs("y", level="coords", axis=1, drop_level=False).iloc[i, :]
        scatter.set_offsets(np.c_[np.array(x), np.array(y)])
        update_mouse_graph(x, y, plots, edges)
        return scatter

    animation = FuncAnimation(
        fig, func=animation_frame, frames=length_slider, interval=100,
    )
    ax.set_title("Positions across time for centered data")
    ax.set_ylim(-100, 60)
    ax.set_xlim(-60, 60)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    video = animation.to_html5_video()
    html = display.HTML(video)
    display.display(html)
    plt.close()
```
%%%% Output: display_data
%% Cell type:markdown id: tags:
### 3. Pass data through all models
The figure above displays exactly the same data as the multi time-point scatter plot, but in the form of a video (one training instance at a time).
%% Cell type:markdown id: tags:
### 5. Visualize speeds
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
# Run the preprocessed windows through each trained sub-model:
# encoder -> latent representations, grouper -> cluster assignments,
# gmvaep -> full reconstructions (names per the models built above)
encodings = encoder.predict(preprocessed_data)
groupings = grouper.predict(preprocessed_data)
reconstrs = gmvaep.predict(preprocessed_data)
```
%% Cell type:markdown id: tags:
### 4. Evaluate reconstruction
### 6. Visualize acceleration
%% Cell type:code id: tags:
``` python
......@@ -220,11 +311,11 @@
```
%% Cell type:markdown id: tags:
### 5. Evaluate latent space
### 7. Visualize jerk
%% Cell type:code id: tags:
``` python
......
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:markdown id: tags:
# deepOF model evaluation
%% Cell type:markdown id: tags:
Given a dataset and a trained model, this notebook allows the user to
* Load and inspect the different models (encoder, decoder, grouper, gmvaep)
* Visualize reconstruction quality for a given model
* Visualize a static latent space
* Visualize trajectories on the latent space for a given video
* Sample from the latent space distributions and generate video clips showcasing the generated data
%% Cell type:code id: tags:
``` python
import os

# Move one directory up so the deepof package root becomes importable;
# os.path.dirname("../") evaluates to "..".
parent_dir = os.path.dirname("../")
os.chdir(parent_dir)
```
%% Cell type:code id: tags:
``` python
import deepof.data
import deepof.utils
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from ipywidgets import interact
from IPython import display
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interact
```
%% Cell type:markdown id: tags:
### 1. Define and run project
%% Cell type:code id: tags:
``` python
# Dataset location and directory holding the pretrained network weights.
# os.pardir is the platform's parent-directory marker ("..").
path = os.path.join(os.pardir, os.pardir, "Desktop", "deepoftesttemp")
trained_network = os.path.join(os.pardir, os.pardir, "Desktop")
```
%% Cell type:code id: tags:
``` python
%%time
# Build a deepof project from the tracking data in `path`, smoothing
# trajectories and excluding tail body parts.
# NOTE(review): parameter semantics come from deepof.data.project, which is
# not visible here — confirm smooth_alpha/arena_dims against its docs.
proj = deepof.data.project(
path=path,
smooth_alpha=0.99,
exclude_bodyparts=["Tail_1", "Tail_2", "Tail_tip", "Tail_base"],
arena_dims=[380],
)
```
%% Cell type:code id: tags:
``` python
%%time
proj = proj.run(verbose=True)
print(proj)
```
%% Cell type:markdown id: tags:
### 2. Load pretrained deepof model
%% Cell type:code id: tags:
``` python
# Centered, spine-aligned coordinates, then sliding-window preprocessing
# (window_size=24 frames, stride 5, no test split). Only the first of the
# four returned arrays is kept.
coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
preprocessed_data,_,_,_ = coords.preprocess(test_videos=0, window_step=5, window_size=24)
```
%% Cell type:code id: tags:
``` python
# Hyperparameters matching the pretrained SEQ_2_SEQ_GMVAE checkpoint
encoding = 6      # latent encoding size (passed as encoding= below)
loss = "ELBO"     # loss identifier handed to the model builder
k = 25            # number_of_components of the Gaussian mixture
pheno = 0         # phenotype_prediction weight (0 = disabled)
predictor = 0     # predictor branch weight (0 = disabled)
```
%% Cell type:code id: tags:
``` python
# Build the GMVAE and keep its first four sub-models, then load pretrained
# weights from the first "h5"-suffixed file in trained_network.
# NOTE(review): build() presumably needs the (windows, timesteps, features)
# shape of preprocessed_data — confirm against deepof.models.
encoder, decoder, grouper, gmvaep = deepof.models.SEQ_2_SEQ_GMVAE(
loss=loss,
number_of_components=k,
compile_model=True,
encoding=encoding,
predictor=predictor,
phenotype_prediction=pheno,
).build(preprocessed_data.shape)[:4]
gmvaep.load_weights(
os.path.join(
trained_network, [i for i in os.listdir(trained_network) if i.endswith("h5")][0]
)
)
```
%% Cell type:code id: tags:
``` python
# Uncomment to see model summaries
# encoder.summary()
# decoder.summary()
# grouper.summary()
# gmvaep.summary()
```
%% Cell type:code id: tags:
``` python
# Uncomment to plot model structure
def plot_model(model, name):
    """Write a timestamped PNG diagram of *model* into the project path."""
    # Bug fix: this notebook's import cell never imports datetime, so the
    # original body raised NameError when called; import it locally.
    from datetime import datetime

    tf.keras.utils.plot_model(
        model,
        to_file=os.path.join(
            path,
            "deepof_{}_{}.png".format(name, datetime.now().strftime("%Y%m%d-%H%M%S")),
        ),
        show_shapes=True,
        show_dtype=False,
        show_layer_names=True,
        rankdir="TB",
        expand_nested=True,
        dpi=200,
    )
# plot_model(encoder, "encoder")
# plot_model(decoder, "decoder")
# plot_model(grouper, "grouper")
# plot_model(gmvaep, "gmvaep")
```
%% Cell type:markdown id: tags:
### 4. Pass data through all models
%% Cell type:code id: tags:
``` python
# Run the preprocessed windows through each trained sub-model:
# encoder -> latent representations, grouper -> cluster assignments,
# gmvaep -> full reconstructions (names per the models built above)
encodings = encoder.predict(preprocessed_data)
groupings = grouper.predict(preprocessed_data)
reconstrs = gmvaep.predict(preprocessed_data)
```
%% Cell type:markdown id: tags:
### 5. Evaluate reconstruction (to be incorporated into deepof.evaluate)
%% Cell type:code id: tags:
``` python
# Fit a scaler to the data, to back-transform reconstructions later
# NOTE(review): fitted only on the first timestep of each window
# (preprocessed_data[:, 0, :]) — confirm this matches the scaling applied
# during training.
scaler = StandardScaler().fit(preprocessed_data[:, 0, :])
```
%% Cell type:code id: tags:
``` python
# Rescale reconstructions
# Flatten (windows, timesteps, features) -> (windows*timesteps, features),
# apply the fitted scaler, then restore the original 3-D shape.
# NOTE(review): the comment above the scaler fit says "back-transform", yet
# this calls transform(), not inverse_transform() — behavior kept as-is;
# worth confirming which direction was intended.
flat_reconstructions = reconstrs.reshape(
    reconstrs.shape[0] * reconstrs.shape[1], reconstrs.shape[2]
)
rescaled_reconstructions = scaler.transform(flat_reconstructions).reshape(
    reconstrs.shape
)
```
%% Cell type:code id: tags:
``` python
# Display a video with the original data superimposed with the reconstructions
```
%% Cell type:markdown id: tags:
### 6. Evaluate latent space (to be incorporated into deepof.evaluate)
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment