Commit 156f1af2 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented rotation alignment in preprocess.py

parent a201ddb9
This diff is collapsed.
This diff is collapsed.
...@@ -232,7 +232,7 @@ DLC_social_1_coords = DLC_social_1.run(verbose=True) ...@@ -232,7 +232,7 @@ DLC_social_1_coords = DLC_social_1.run(verbose=True)
DLC_social_2_coords = DLC_social_2.run(verbose=True) DLC_social_2_coords = DLC_social_2.run(verbose=True)
# Coordinates for training data # Coordinates for training data
coords1 = DLC_social_1_coords.get_coords(center="B_Center", polar=False) coords1 = DLC_social_1_coords.get_coords(center="B_Center", align="B_Nose")
distances1 = DLC_social_1_coords.get_distances() distances1 = DLC_social_1_coords.get_distances()
angles1 = DLC_social_1_coords.get_angles() angles1 = DLC_social_1_coords.get_angles()
coords_distances1 = merge_tables(coords1, distances1) coords_distances1 = merge_tables(coords1, distances1)
...@@ -241,7 +241,7 @@ dists_angles1 = merge_tables(distances1, angles1) ...@@ -241,7 +241,7 @@ dists_angles1 = merge_tables(distances1, angles1)
coords_dist_angles1 = merge_tables(coords1, distances1, angles1) coords_dist_angles1 = merge_tables(coords1, distances1, angles1)
# Coordinates for validation data # Coordinates for validation data
coords2 = DLC_social_2_coords.get_coords(center="B_Center", polar=False) coords2 = DLC_social_2_coords.get_coords(center="B_Center", align="B_Nose")
distances2 = DLC_social_2_coords.get_distances() distances2 = DLC_social_2_coords.get_distances()
angles2 = DLC_social_2_coords.get_angles() angles2 = DLC_social_2_coords.get_angles()
coords_distances2 = merge_tables(coords2, distances2) coords_distances2 = merge_tables(coords2, distances2)
...@@ -258,6 +258,7 @@ input_dict_train = { ...@@ -258,6 +258,7 @@ input_dict_train = {
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
align=True,
), ),
"dists": distances1.preprocess( "dists": distances1.preprocess(
window_size=11, window_size=11,
...@@ -266,6 +267,7 @@ input_dict_train = { ...@@ -266,6 +267,7 @@ input_dict_train = {
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
align=True,
), ),
"angles": angles1.preprocess( "angles": angles1.preprocess(
window_size=11, window_size=11,
...@@ -274,6 +276,7 @@ input_dict_train = { ...@@ -274,6 +276,7 @@ input_dict_train = {
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
align=True,
), ),
"coords+dist": coords_distances1.preprocess( "coords+dist": coords_distances1.preprocess(
window_size=11, window_size=11,
...@@ -282,6 +285,7 @@ input_dict_train = { ...@@ -282,6 +285,7 @@ input_dict_train = {
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
align=True,
), ),
"coords+angle": coords_angles1.preprocess( "coords+angle": coords_angles1.preprocess(
window_size=11, window_size=11,
...@@ -290,6 +294,7 @@ input_dict_train = { ...@@ -290,6 +294,7 @@ input_dict_train = {
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
align=True,
), ),
"dists+angle": dists_angles1.preprocess( "dists+angle": dists_angles1.preprocess(
window_size=11, window_size=11,
...@@ -298,6 +303,7 @@ input_dict_train = { ...@@ -298,6 +303,7 @@ input_dict_train = {
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
align=True,
), ),
"coords+dist+angle": coords_dist_angles1.preprocess( "coords+dist+angle": coords_dist_angles1.preprocess(
window_size=11, window_size=11,
...@@ -306,54 +312,60 @@ input_dict_train = { ...@@ -306,54 +312,60 @@ input_dict_train = {
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
align=True,
), ),
} }
input_dict_val = { input_dict_val = {
"coords": coords2.preprocess( "coords": coords2.preprocess(
window_size=11, window_size=11,
window_step=1, window_step=10,
scale=True, scale=True,
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align=True,
), ),
"dists": distances2.preprocess( "dists": distances2.preprocess(
window_size=11, window_size=11,
window_step=1, window_step=10,
scale=True, scale=True,
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align=True,
), ),
"angles": angles2.preprocess( "angles": angles2.preprocess(
window_size=11, window_size=11,
window_step=1, window_step=10,
scale=True, scale=True,
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align=True,
), ),
"coords+dist": coords_distances2.preprocess( "coords+dist": coords_distances2.preprocess(
window_size=11, window_size=11,
window_step=1, window_step=10,
scale=True, scale=True,
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align=True,
), ),
"coords+angle": coords_angles2.preprocess( "coords+angle": coords_angles2.preprocess(
window_size=11, window_size=11,
window_step=1, window_step=10,
scale=True, scale=True,
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align=True,
), ),
"dists+angle": dists_angles2.preprocess( "dists+angle": dists_angles2.preprocess(
window_size=11, window_size=11,
...@@ -363,15 +375,17 @@ input_dict_val = { ...@@ -363,15 +375,17 @@ input_dict_val = {
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align=True,
), ),
"coords+dist+angle": coords_dist_angles2.preprocess( "coords+dist+angle": coords_dist_angles2.preprocess(
window_size=11, window_size=11,
window_step=1, window_step=10,
scale=True, scale=True,
random_state=42, random_state=42,
filter="gaussian", filter="gaussian",
sigma=55, sigma=55,
shuffle=True, shuffle=True,
align=True,
), ),
} }
......
...@@ -309,7 +309,9 @@ class coordinates: ...@@ -309,7 +309,9 @@ class coordinates:
else: else:
return "DLC analysis of {} videos".format(len(self._videos)) return "DLC analysis of {} videos".format(len(self._videos))
def get_coords(self, center="arena", polar=False, speed=0, length=None): def get_coords(
self, center="arena", polar=False, speed=0, length=None, align=False
):
tabs = deepcopy(self._tables) tabs = deepcopy(self._tables)
if polar: if polar:
...@@ -384,6 +386,21 @@ class coordinates: ...@@ -384,6 +386,21 @@ class coordinates:
"00:00:00", length, periods=tab.shape[0] + 1, closed="left" "00:00:00", length, periods=tab.shape[0] + 1, closed="left"
) )
if align:
assert (
align in list(tabs.values())[0].columns.levels[0]
), "align must be set to the name of a bodypart"
for key, tab in tabs.items():
# Bring forward the column to align
columns = [i for i in tab.columns if align not in i]
columns = [
(align, ("phi" if polar else "x")),
(align, ("rho" if polar else "y")),
] + columns
tabs[key] = tab[columns]
return table_dict( return table_dict(
tabs, tabs,
"coords", "coords",
...@@ -524,17 +541,17 @@ class table_dict(dict): ...@@ -524,17 +541,17 @@ class table_dict(dict):
self, self,
window_size=1, window_size=1,
window_step=1, window_step=1,
scale=True, scale="standard",
test_proportion=0, test_proportion=0,
random_state=None, random_state=None,
verbose=False, verbose=False,
filter=None, filter=None,
sigma=None, sigma=None,
shift=0, shift=0,
standard_scaler=True,
shuffle=False, shuffle=False,
align=False,
): ):
"""Builds a sliding window. If desired, splits train and test and """Builds a sliding window. If specified, splits train and test and
Z-scores the data using sklearn's standard scaler""" Z-scores the data using sklearn's standard scaler"""
X_train = self.get_training_set() X_train = self.get_training_set()
...@@ -550,16 +567,20 @@ class table_dict(dict): ...@@ -550,16 +567,20 @@ class table_dict(dict):
if verbose: if verbose:
print("Scaling data...") print("Scaling data...")
if standard_scaler: if scale == "standard":
scaler = StandardScaler() scaler = StandardScaler()
else: elif scale == "minmax":
scaler = MinMaxScaler() scaler = MinMaxScaler()
else:
raise ValueError(
"Invalid scaler. Select one of standard, minmax or None"
)
X_train = scaler.fit_transform( X_train = scaler.fit_transform(
X_train.reshape(-1, X_train.shape[-1]) X_train.reshape(-1, X_train.shape[-1])
).reshape(X_train.shape) ).reshape(X_train.shape)
if standard_scaler: if scale == "standard":
assert np.allclose(np.mean(X_train), 0) assert np.allclose(np.mean(X_train), 0)
assert np.allclose(np.std(X_train), 1) assert np.allclose(np.std(X_train), 1)
...@@ -573,6 +594,9 @@ class table_dict(dict): ...@@ -573,6 +594,9 @@ class table_dict(dict):
X_train = rolling_window(X_train, window_size, window_step) X_train = rolling_window(X_train, window_size, window_step)
if align:
X_train = align_trajectories(X_train)
if filter == "gaussian": if filter == "gaussian":
r = range(-int(window_size / 2), int(window_size / 2) + 1) r = range(-int(window_size / 2), int(window_size / 2) + 1)
r = [i - shift for i in r] r = [i - shift for i in r]
......
...@@ -11,6 +11,7 @@ import pims ...@@ -11,6 +11,7 @@ import pims
import re import re
import scipy import scipy
import seaborn as sns import seaborn as sns
from copy import deepcopy
from itertools import cycle, combinations, product from itertools import cycle, combinations, product
from joblib import Parallel, delayed from joblib import Parallel, delayed
from numba import jit from numba import jit
...@@ -87,6 +88,32 @@ def angle_trio(array, degrees=False): ...@@ -87,6 +88,32 @@ def angle_trio(array, degrees=False):
return np.array([angle(a, b, c), angle(a, c, b), angle(b, a, c),]) return np.array([angle(a, b, c), angle(a, c, b), angle(b, a, c),])
def rotate(p, angles, origin=np.array([0, 0])):
R = np.array([[np.cos(angles), -np.sin(angles)], [np.sin(angles), np.cos(angles)]])
o = np.atleast_2d(origin)
p = np.atleast_2d(p)
return np.squeeze((R @ (p.T - o.T) + o.T).T)
def align_trajectories(data):
data = deepcopy(data)
center_time = (data.shape[1] - 1) // 2
angles = np.arctan2(data[:, center_time, 0], data[:, center_time, 1])
aligned_trajs = np.zeros(data.shape)
for frame in range(data.shape[0]):
aligned_trajs[frame] = rotate(
data[frame].reshape([data.shape[1] * data.shape[2] // 2, 2]), angles[frame],
).reshape(data.shape[1:])
return aligned_trajs
def smooth_boolean_array(a): def smooth_boolean_array(a):
"""Returns a boolean array in which isolated appearances of a feature are smoothened""" """Returns a boolean array in which isolated appearances of a feature are smoothened"""
for i in range(1, len(a) - 1): for i in range(1, len(a) - 1):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment