Commit fca79d55 authored by lucas_miranda

Added a _scaler attribute to table_dict, to retrieve the scaler used to preprocess the data

parent aadc4d27
@@ -997,6 +997,7 @@ class table_dict(dict):
self._arena_dims = arena_dims
self._propagate_labels = propagate_labels
self._propagate_annotations = propagate_annotations
self._scaler = None
def filter_videos(self, keys: list) -> Table_dict:
"""Returns a subset of the original table_dict object, containing only the specified keys. Useful, for example,
@@ -1156,15 +1157,15 @@ class table_dict(dict):
print("Scaling data...")
if scale == "standard":
scaler = StandardScaler()
self._scaler = StandardScaler()
elif scale == "minmax":
scaler = MinMaxScaler()
self._scaler = MinMaxScaler()
else:
raise ValueError(
"Invalid scaler. Select one of standard, minmax or None"
) # pragma: no cover
X_train = scaler.fit_transform(
X_train = self._scaler.fit_transform(
X_train.reshape(-1, X_train.shape[-1])
).reshape(X_train.shape)
@@ -1173,7 +1174,7 @@ class table_dict(dict):
assert np.allclose(np.nan_to_num(np.std(X_train), nan=1), 1)
if test_videos:
X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
X_test = self._scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
X_test.shape
)
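As a minimal sketch of the pattern this hunk introduces (plain scikit-learn, not deepOF's API): a single scaler is fitted on the flattened training tensor, kept around (the diff stores it in the new `self._scaler` attribute), and then reused, without refitting, on the test tensor.

``` python
# Minimal sketch of the scaling pattern above, using plain scikit-learn.
# Shapes are illustrative: (instances, window, features).
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.random.normal(size=(100, 24, 10))
X_test = np.random.normal(size=(20, 24, 10))

scaler = StandardScaler()  # stored as self._scaler in the diff
X_train_scaled = scaler.fit_transform(
    X_train.reshape(-1, X_train.shape[-1])
).reshape(X_train.shape)

# The test set is only transformed, never refitted, so no information leaks
# from test to train; keeping the fitted scaler also makes inverse_transform
# available for model outputs later on.
X_test_scaled = scaler.transform(
    X_test.reshape(-1, X_test.shape[-1])
).reshape(X_test.shape)
```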
@@ -3,10 +3,17 @@
``` python
import os
os.chdir(os.path.dirname("../"))
```
%% Cell type:code id: tags:
``` python
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:markdown id: tags:
# deepOF data exploration
%% Cell type:markdown id: tags:
@@ -19,17 +26,10 @@
* Visualize training instances as video clips with interactive configurations
%% Cell type:code id: tags:
``` python
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:code id: tags:
``` python
import deepof.data
import deepof.utils
import numpy as np
import pandas as pd
import tensorflow as tf
@@ -46,15 +46,22 @@
### 1. Define and run project
%% Cell type:code id: tags:
``` python
exclude_bodyparts = tuple([""])
```
%% Cell type:code id: tags:
``` python
# Use deepof to load a project
proj = deepof.data.project(
path="../../Desktop/deepoftesttemp/",
arena_dims=[380],
arena_detection="rule-based",
exclude_bodyparts=exclude_bodyparts,
interpolate_outliers=True,
).run()
```
%% Cell type:markdown id: tags:
@@ -191,10 +198,12 @@
The figure above is a multi-time-point scatter plot. The time_slider lets you scroll through the video, and the length_slider selects the number of time points to include. The aim is to visualize, intuitively, the data that goes into a single training instance for a given preprocessing setting.
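These sliders come from ipywidgets' @interact decorator, which the animation cells below use as well. A minimal, standalone sketch of just the slider wiring (the callback name and the printed message are illustrative, not part of the notebook):

``` python
# Minimal slider wiring with ipywidgets; the real callback draws the scatter plot.
from ipywidgets import interact

@interact(time_slider=(0.0, 15000, 500), length_slider=(0, 1000, 100))
def preview_window(time_slider, length_slider):
    # The notebook uses these values to slice rows
    # [time_slider, time_slider + length_slider) from the coordinate table.
    print(f"frames {int(time_slider)} to {int(time_slider + length_slider)}")
```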
%% Cell type:code id: tags:
``` python
# Auxiliary animation functions
def plot_mouse_graph(instant_x, instant_y, ax, edges):
"""Generates a graph plot of the mouse"""
plots = []
for edge in edges:
(temp_plot,) = ax.plot(
@@ -226,10 +235,14 @@
def animate_mice_across_time(time_slider, length_slider):
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
edges = deepof.utils.connect_mouse_topview()
for bpart in exclude_bodyparts:
edges.remove_node(bpart)
for limb in ["Left_fhip", "Right_fhip", "Left_bhip", "Right_bhip"]:
edges.remove_edge("Center", limb)
edges = edges.edges()
data = position_coords[random_exp].loc[
@@ -237,21 +250,21 @@
]
data["Center", "x"] = 0
data["Center", "y"] = 0
init_x = data.xs("x", level="coords", axis=1, drop_level=False).iloc[0, :]
init_y = data.xs("y", level="coords", axis=1, drop_level=False).iloc[0, :]
init_x = data.xs("x", level=1, axis=1, drop_level=False).iloc[0, :]
init_y = data.xs("y", level=1, axis=1, drop_level=False).iloc[0, :]
plots = plot_mouse_graph(init_x, init_y, ax, edges)
scatter = ax.scatter(x=np.array(init_x), y=np.array(init_y), color="#006699",)
# Update data in main plot
def animation_frame(i):
# Update scatter plot
x = data.xs("x", level="coords", axis=1, drop_level=False).iloc[i, :]
y = data.xs("y", level="coords", axis=1, drop_level=False).iloc[i, :]
x = data.xs("x", level=1, axis=1, drop_level=False).iloc[i, :]
y = data.xs("y", level=1, axis=1, drop_level=False).iloc[i, :]
scatter.set_offsets(np.c_[np.array(x), np.array(y)])
update_mouse_graph(x, y, plots, edges)
return scatter
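The .xs() calls in this hunk switch from naming the column level (level="coords") to addressing it by position (level=1), which also works when the MultiIndex level carries no name. A minimal sketch with a toy DataFrame (the column labels are illustrative):

``` python
# Cross-sectioning a column MultiIndex by level position instead of level name.
import numpy as np
import pandas as pd

columns = pd.MultiIndex.from_product([["Nose", "Center"], ["x", "y"]])
data = pd.DataFrame(np.random.rand(3, 4), columns=columns)

# level=1 selects the coordinate level by position; this keeps working even if
# the level was never given the name "coords".
x_coords = data.xs("x", level=1, axis=1, drop_level=False)
print(x_coords.columns.tolist())  # [('Nose', 'x'), ('Center', 'x')]
```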
@@ -3,10 +3,17 @@
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:markdown id: tags:
# deepOF model evaluation
%% Cell type:markdown id: tags:
@@ -52,20 +59,22 @@
%% Cell type:code id: tags:
``` python
path = os.path.join("..", "..", "Desktop", "deepoftesttemp")
trained_network = os.path.join("..", "..", "Desktop")
exclude_bodyparts = ["Tail_1", "Tail_2", "Tail_tip", "Tail_base"]
window_size = 24
```
%% Cell type:code id: tags:
``` python
%%time
proj = deepof.data.project(
path=path,
smooth_alpha=0.99,
exclude_bodyparts=["Tail_1", "Tail_2", "Tail_tip", "Tail_base"],
exclude_bodyparts=exclude_bodyparts,
arena_dims=[380],
)
```
%% Cell type:code id: tags:
@@ -82,11 +91,11 @@
%% Cell type:code id: tags:
``` python
coords = proj.get_coords(center="Center", align="Spine_1", align_inplace=True)
preprocessed_data,_,_,_ = coords.preprocess(test_videos=0, window_step=5, window_size=24)
preprocessed_data,_,_,_ = coords.preprocess(test_videos=0, window_step=5, window_size=window_size)
```
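The window_size and window_step arguments control how the per-frame coordinate table is cut into overlapping training instances. A rough sketch of the idea (plain NumPy, not deepOF's implementation; shapes are illustrative):

``` python
# Slice a (frames, features) array into overlapping windows:
# the result has shape (n_instances, window_size, features).
import numpy as np

def sliding_windows(table, window_size, window_step):
    starts = range(0, len(table) - window_size + 1, window_step)
    return np.stack([table[s : s + window_size] for s in starts])

frames = np.random.rand(200, 26)  # e.g. 13 body parts x 2 coordinates
windows = sliding_windows(frames, window_size=24, window_step=5)
print(windows.shape)  # (36, 24, 26)
```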
%% Cell type:code id: tags:
``` python
@@ -115,10 +124,34 @@
trained_network, [i for i in os.listdir(trained_network) if i.endswith("h5")][0]
)
)
```
%%%% Output: error
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-58-ea431bf97d05> in <module>
10 gmvaep.load_weights(
11 os.path.join(
---> 12 trained_network, [i for i in os.listdir(trained_network) if i.endswith("h5")][0]
13 )
14 )
~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in load_weights(self, filepath, by_name, skip_mismatch, options)
2232 f, self.layers, skip_mismatch=skip_mismatch)
2233 else:
-> 2234 hdf5_format.load_weights_from_hdf5_group(f, self.layers)
2235
2236 def _updated_config(self):
~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/saving/hdf5_format.py in load_weights_from_hdf5_group(f, layers)
686 'containing ' + str(len(layer_names)) +
687 ' layers into a model with ' + str(len(filtered_layers)) +
--> 688 ' layers.')
689
690 # We batch weight value assignments in a single backend call
ValueError: You are trying to load a weight file containing 15 layers into a model with 14 layers.
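The traceback above is the usual symptom of rebuilding a model whose architecture no longer matches the checkpoint (here 15 saved layers versus 14 in the rebuilt model). One hedged workaround, reusing the checkpoint lookup from the cell above, is Keras' name-based loading, which skips layers whose names or shapes do not match; skipped layers keep their fresh initialization, so this mainly helps diagnose which layer changed:

``` python
# Load only the layers whose names and shapes match the checkpoint;
# mismatching layers are skipped instead of raising a ValueError.
gmvaep.load_weights(
    os.path.join(
        trained_network,
        [i for i in os.listdir(trained_network) if i.endswith("h5")][0],
    ),
    by_name=True,
    skip_mismatch=True,
)
```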
%% Cell type:code id: tags:
``` python
# Uncomment to see model summaries
# encoder.summary()
@@ -178,31 +211,147 @@
%% Cell type:code id: tags:
``` python
# Rescale reconstructions
rescaled_reconstructions = scaler.transform(
rescaled_reconstructions = scaler.inverse_transform(
reconstrs.reshape(reconstrs.shape[0] * reconstrs.shape[1], reconstrs.shape[2])
)
rescaled_reconstructions = rescaled_reconstructions.reshape(reconstrs.shape)
```
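The fix in this hunk matters because the decoder's outputs live in the scaled space: transform would push them through the scaling a second time, while inverse_transform maps them back to the original coordinate units. A minimal standalone illustration:

``` python
import numpy as np
from sklearn.preprocessing import StandardScaler

original = np.random.normal(loc=100.0, scale=15.0, size=(500, 2))
scaler = StandardScaler().fit(original)
scaled = scaler.transform(original)  # original units -> scaled space

# inverse_transform undoes the scaling, which is what the reconstructions need
# before they can be compared with (or plotted against) the raw coordinates.
recovered = scaler.inverse_transform(scaled)
assert np.allclose(recovered, original)
```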
%% Cell type:code id: tags:
``` python
# Auxiliary animation functions
def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):
"""Generates a graph plot of the mouse"""
plots = []
rec_plots = []
for edge in edges:
(temp_plot,) = ax.plot(
[float(instant_x[edge[0]]), float(instant_x[edge[1]])],
[float(instant_y[edge[0]]), float(instant_y[edge[1]])],
color="#006699",
)
(temp_rec_plot,) = ax.plot(
[float(instant_rec_x[edge[0]]), float(instant_rec_x[edge[1]])],
[float(instant_rec_y[edge[0]]), float(instant_rec_y[edge[1]])],
color="#006699",
)
plots.append(temp_plot)
rec_plots.append(temp_rec_plot)
return plots
def update_mouse_graph(x, y, plots, edges):
"""Updates the graph plot to enable animation"""
for plot, edge in zip(plots, edges):
plot.set_data(
[float(x[edge[0]]), float(x[edge[1]])],
[float(y[edge[0]]), float(y[edge[1]])],
)
```
%% Cell type:code id: tags:
``` python
# Display a video with the original data superimposed with the reconstructions
random_exp = np.random.choice(list(coords.keys()), 1)[0]
@interact(time_slider=(0.0, 15000, 500), length_slider=(0, 1000, 100))
def animate_mice_across_time(time_slider, length_slider):
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
edges = deepof.utils.connect_mouse_topview()
for bpart in exclude_bodyparts:
edges.remove_node(bpart)
for limb in ["Left_fhip", "Right_fhip", "Left_bhip", "Right_bhip"]:
edges.remove_edge("Center", limb)
edges = edges.edges()
data = coords[random_exp].loc[time_slider : time_slider + length_slider - 1, :]
data_rec = gmvaep.predict(
coords.filter_videos([random_exp]).preprocess(
test_videos=0, window_step=5, window_size=window_size
)[0]
)
data_rec = pd.DataFrame(scaler.inverse_transform(data_rec[:, window_size // 2, :]))
data_rec.columns = data.columns
data["Center", "x"] = 0
data["Center", "y"] = 0
data_rec["Center", "x"] = 0
data_rec["Center", "y"] = 0
init_x = data.xs("x", level=1, axis=1, drop_level=False).iloc[0, :]
init_y = data.xs("y", level=1, axis=1, drop_level=False).iloc[0, :]
init_rec_x = data_rec.xs("x", level=1, axis=1, drop_level=False).iloc[0, :]
init_rec_y = data_rec.xs("y", level=1, axis=1, drop_level=False).iloc[0, :]
plots = plot_mouse_graph(init_x, init_y, init_rec_x, init_rec_y, ax, edges)
scatter = ax.scatter(x=np.array(init_x), y=np.array(init_y), color="#006699",)
rec_scatter = ax.scatter(
x=np.array(init_rec_x), y=np.array(init_rec_y), color="#006699",
)
# Update data in main plot
def animation_frame(i):
# Update scatter plot
x = data.xs("x", level=1, axis=1, drop_level=False).iloc[i, :]
y = data.xs("y", level=1, axis=1, drop_level=False).iloc[i, :]
rec_x = data_rec.xs("x", level=1, axis=1, drop_level=False).iloc[i, :]
rec_y = data_rec.xs("y", level=1, axis=1, drop_level=False).iloc[i, :]
scatter.set_offsets(np.c_[np.array(x), np.array(y)])
rec_scatter.set_offsets(np.c_[np.array(rec_x), np.array(rec_y)])
update_mouse_graph(x, y, plots, edges)
return scatter
animation = FuncAnimation(
fig, func=animation_frame, frames=length_slider, interval=100,
)
ax.set_title("Positions across time for centered data")
ax.set_ylim(-100, 60)
ax.set_xlim(-60, 60)
ax.set_xlabel("x")
ax.set_ylabel("y")
video = animation.to_html5_video()
html = display.HTML(video)
display.display(html)
plt.close()
```
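Rendering the FuncAnimation inline via to_html5_video() requires ffmpeg to be available to matplotlib. A stripped-down version of that display pattern, independent of the mouse data:

``` python
# Minimal inline-animation pattern: animate a single point and embed the video.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython import display

fig, ax = plt.subplots(figsize=(4, 4))
point = ax.scatter([0], [0])
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)

def animation_frame(i):
    angle = 2 * np.pi * i / 50
    point.set_offsets(np.c_[[np.cos(angle)], [np.sin(angle)]])
    return point

animation = FuncAnimation(fig, func=animation_frame, frames=50, interval=100)
display.display(display.HTML(animation.to_html5_video()))  # needs ffmpeg
plt.close()
```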
%%%% Output: display_data
%% Cell type:markdown id: tags:
### 6. Evaluate latent space (to be incorporated into deepof.evaluate)
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python