Commit fca79d55 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added _scaler method to table_dict, to retrieve the scaler used to preprocess data

parent aadc4d27
......@@ -997,6 +997,7 @@ class table_dict(dict):
self._arena_dims = arena_dims
self._propagate_labels = propagate_labels
self._propagate_annotations = propagate_annotations
self._scaler = None
def filter_videos(self, keys: list) -> Table_dict:
"""Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
......@@ -1156,15 +1157,15 @@ class table_dict(dict):
print("Scaling data...")
if scale == "standard":
scaler = StandardScaler()
self._scaler = StandardScaler()
elif scale == "minmax":
scaler = MinMaxScaler()
self._scaler = MinMaxScaler()
else:
raise ValueError(
"Invalid scaler. Select one of standard, minmax or None"
) # pragma: no cover
X_train = scaler.fit_transform(
X_train = self._scaler.fit_transform(
X_train.reshape(-1, X_train.shape[-1])
).reshape(X_train.shape)
......@@ -1173,7 +1174,7 @@ class table_dict(dict):
assert np.allclose(np.nan_to_num(np.std(X_train), nan=1), 1)
if test_videos:
X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
X_test = self._scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
X_test.shape
)
......
%% Cell type:code id: tags:
``` python
import os
os.chdir(os.path.dirname("../"))
```
%% Cell type:code id: tags:
``` python
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:markdown id: tags:
# deepOF data exploration
%% Cell type:markdown id: tags:
Given a dataset, this notebook allows the user to
* Load and process the dataset using deepof.data
* Visualize data quality with interactive plots
* Visualize training instances as multi-timepoint scatter plots with interactive configurations
* Visualize training instances as video clips with interactive configurations
%% Cell type:code id: tags:
``` python
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
import deepof.data
import deepof.utils
import numpy as np
import pandas as pd
import tensorflow as tf
from ipywidgets import interact
from IPython import display
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import seaborn as sns
```
%% Cell type:markdown id: tags:
### 1. Define and run project
%% Cell type:code id: tags:
``` python
# Body parts to drop before processing; ("",) looks like a no-op placeholder.
# NOTE(review): animate_mice_across_time calls edges.remove_node(bpart) for each
# entry — confirm connect_mouse_topview() actually contains a "" node, otherwise
# that removal would raise.
exclude_bodyparts = tuple([""])
```
%% Cell type:code id: tags:
``` python
# Use deepof to load a project
# NOTE(review): arena_dims=[380] is presumably the arena diameter in mm — confirm
# against deepof.data.project documentation before relying on scaled distances.
proj = deepof.data.project(
path="../../Desktop/deepoftesttemp/",
arena_dims=[380],
arena_detection="rule-based",
exclude_bodyparts=exclude_bodyparts,
interpolate_outliers=True,
).run()
```
%% Output
Loading trajectories...
Smoothing trajectories...
Interpolating outliers...
Iterative imputation of occluded bodyparts...
Computing distances...
Computing angles...
Done!
%% Cell type:markdown id: tags:
### 2. Inspect dataset quality
%% Cell type:code id: tags:
``` python
# Stack the per-video DeepLabCut quality tables into a single frame for plotting
all_quality = pd.concat([tab for tab in proj.get_quality().values()])
```
%% Cell type:code id: tags:
``` python
# Per-body-part likelihood distribution; y-limits zoom into the top of the range
all_quality.boxplot(rot=45)
plt.ylim(0.99985, 1.00001)
plt.show()
```
%% Output
%% Cell type:code id: tags:
``` python
@interact(quality_top=(0.0, 1.0, 0.01))
def low_quality_tags(quality_top):
    """Bar chart of the % of tags per body part with DLC quality below quality_top."""
    melted = pd.melt(all_quality)
    share_below = melted.groupby("bodyparts").value.apply(
        lambda q: sum(q < quality_top) / len(q) * 100
    )
    pd.DataFrame(share_below).sort_values(by="value", ascending=False).plot.bar(rot=45)
    plt.xlabel("body part")
    plt.ylabel("Tags with quality under {} (%)".format(quality_top * 100))
    plt.tight_layout()
    plt.legend([])
    plt.show()
```
%% Output
%% Cell type:markdown id: tags:
In the cell above, you see the percentage of labels per body part which have a quality lower than the selected value (0.50 by default) **before** preprocessing. The values are taken directly from DeepLabCut.
%% Cell type:markdown id: tags:
### 3. Get coordinates, distances and angles
%% Cell type:markdown id: tags:
And get speed, acceleration and jerk for each
%% Cell type:code id: tags:
``` python
# Get coordinates, speeds, accelerations and jerks for positions
# speed=k presumably returns the k-th time derivative (1=speed, 2=acceleration,
# 3=jerk) of the centered coordinates — matches the markdown above; confirm in deepof docs.
position_coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
position_speeds = proj.get_coords(center="Center", speed=1)
position_accels = proj.get_coords(center="Center", speed=2)
position_jerks = proj.get_coords(center="Center", speed=3)
```
%% Cell type:code id: tags:
``` python
# Get coordinates, speeds, accelerations and jerks for distances
# speed=k: successive time derivatives of the pairwise distances (see markdown above)
distance_coords = proj.get_distances()
distance_speeds = proj.get_distances(speed=1)
distance_accels = proj.get_distances(speed=2)
distance_jerks = proj.get_distances(speed=3)
```
%% Cell type:code id: tags:
``` python
# Get coordinates, speeds, accelerations and jerks for angles
# speed=k: successive time derivatives of the joint angles (see markdown above)
angle_coords = proj.get_angles()
angle_speeds = proj.get_angles(speed=1)
angle_accels = proj.get_angles(speed=2)
angle_jerks = proj.get_angles(speed=3)
```
%% Cell type:markdown id: tags:
### 4. Display training instances
%% Cell type:code id: tags:
``` python
# Pick one experiment at random to visualize
random_exp = np.random.choice(list(position_coords.keys()), 1)[0]

@interact(time_slider=(0.0, 15000, 25), length_slider=(10, 100, 5))
def plot_mice_across_time(time_slider, length_slider):
    """Scatter-plot every non-center body part over the selected time window."""
    plt.figure(figsize=(10, 10))
    coords = position_coords[random_exp]
    for part in coords.columns.levels[0]:
        if part == "Center":
            continue
        window = coords.loc[time_slider : time_slider + length_slider - 1, part]
        sns.scatterplot(
            data=window,
            x="x",
            y="y",
            label=part,
            palette=sns.color_palette("tab10"),
        )
    plt.title("Positions across time for centered data")
    plt.legend(
        fontsize=15,
        bbox_to_anchor=(1.5, 1),
        title="Body part",
        title_fontsize=18,
        shadow=False,
        facecolor="white",
    )
    plt.ylim(-100, 60)
    plt.xlim(-60, 60)
    plt.show()
```
%% Output
%% Cell type:markdown id: tags:
The figure above is a multi time-point scatter plot. The time_slider allows you to scroll across the video, and the length_slider selects the number of time-points to include. The idea is to intuitively visualize the data that goes into a training instance for a given preprocessing setting.
%% Cell type:code id: tags:
``` python
# Auxiliary animation functions
def plot_mouse_graph(instant_x, instant_y, ax, edges):
    """Draw one line segment per skeleton edge on ax and return the created artists."""
    segments = []
    for start, end in edges:
        (segment,) = ax.plot(
            [float(instant_x[start]), float(instant_x[end])],
            [float(instant_y[start]), float(instant_y[end])],
            color="#006699",
        )
        segments.append(segment)
    return segments
def update_mouse_graph(x, y, plots, edges):
    """Refresh each previously drawn skeleton segment in place (enables animation)."""
    for artist, (start, end) in zip(plots, edges):
        artist.set_data(
            [float(x[start]), float(x[end])],
            [float(y[start]), float(y[end])],
        )
```
%% Cell type:code id: tags:
``` python
# Pick one experiment at random to animate
random_exp = np.random.choice(list(position_coords.keys()), 1)[0]

@interact(time_slider=(0.0, 15000, 25), length_slider=(10, 100, 5))
def animate_mice_across_time(time_slider, length_slider):
    """Render the selected time window as an inline HTML5 video of the mouse skeleton."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # Build the body-part graph and drop the nodes/edges we do not plot
    edges = deepof.utils.connect_mouse_topview()
    for bpart in exclude_bodyparts:
        edges.remove_node(bpart)
    for limb in ["Left_fhip", "Right_fhip", "Left_bhip", "Right_bhip"]:
        edges.remove_edge("Center", limb)
    edges = edges.edges()

    data = position_coords[random_exp].loc[
        time_slider : time_slider + length_slider - 1, :
    ]
    # Data is centered on "Center", so that body part sits at the origin
    data["Center", "x"] = 0
    data["Center", "y"] = 0

    # FIX: the original cell computed init_x/init_y (and per-frame x/y) twice —
    # once with level="coords" and again with level=1; the first pair was dead
    # code immediately overwritten. Only the effective (level=1) version is kept.
    init_x = data.xs("x", level=1, axis=1, drop_level=False).iloc[0, :]
    init_y = data.xs("y", level=1, axis=1, drop_level=False).iloc[0, :]

    plots = plot_mouse_graph(init_x, init_y, ax, edges)
    scatter = ax.scatter(x=np.array(init_x), y=np.array(init_y), color="#006699")

    # Update data in main plot
    def animation_frame(i):
        # Update scatter plot and limb segments for frame i
        x = data.xs("x", level=1, axis=1, drop_level=False).iloc[i, :]
        y = data.xs("y", level=1, axis=1, drop_level=False).iloc[i, :]
        scatter.set_offsets(np.c_[np.array(x), np.array(y)])
        update_mouse_graph(x, y, plots, edges)
        return scatter

    animation = FuncAnimation(
        fig, func=animation_frame, frames=length_slider, interval=100,
    )

    ax.set_title("Positions across time for centered data")
    ax.set_ylim(-100, 60)
    ax.set_xlim(-60, 60)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    video = animation.to_html5_video()
    html = display.HTML(video)
    display.display(html)
    plt.close()
```
%% Output
%% Cell type:markdown id: tags:
The figure above displays exactly the same data as the multi time-point scatter plot, but in the form of a video (one training instance at a time).
%% Cell type:markdown id: tags:
### 5. Visualize speeds
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
### 6. Visualize acceleration
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:markdown id: tags:
### 7. Visualize jerk
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
......