Commit 46fd1898 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for preprocess.py

parent 7cf5b0b8
......@@ -8,8 +8,8 @@ test:
- pip install coverage
- pip install -r ./deepof/requirements.txt
- pip install -e deepof/
- coverage run -m pytest
- coverage report -m
- coverage run --source deepof -m pytest
- coverage report -m --omit deepof/setup.py
- coverage xml -o deepof_cov.xml
artifacts:
reports:
......
......@@ -4,12 +4,49 @@ from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import networkx as nx
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
# Connectivity for DLC models
def connect_mouse_topview(animal_id=None) -> nx.Graph:
"""Creates a nx.Graph object with the connectivity of the bodyparts in the
DLC topview model for a single mouse. Used later for angle computing, among others
Parameters:
- animal_id (str): if more than one animal is tagged,
specify the animal identyfier as a string
Returns:
- connectivity (nx.Graph)"""
connectivity = {
"Nose": ["Left_ear", "Right_ear", "Spine_1"],
"Left_ear": ["Right_ear", "Spine_1"],
"Right_ear": ["Spine_1"],
"Spine_1": ["Center", "Left_fhip", "Right_fhip"],
"Center": ["Left_fhip", "Right_fhip", "Spine_2", "Left_bhip", "Right_bhip"],
"Spine_2": ["Left_bhip", "Right_bhip", "Tail_base"],
"Tail_base": ["Tail_1", "Left_bhip", "Right_bhip"],
"Tail_1": ["Tail_2"],
"Tail_2": ["Tail_tip"],
}
connectivity = nx.Graph(connectivity)
if animal_id:
mapping = {
node: "{}_{}".format(animal_id, node) for node in connectivity.nodes()
}
nx.relabel_nodes(connectivity, mapping, copy=False)
return connectivity
# Helper functions
@tf.function
def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000):
......
......@@ -13,6 +13,7 @@ import networkx as nx
from deepof.utils import *
from deepof.visuals import *
from deepof.model_utils import connect_mouse_topview
class project:
......@@ -27,7 +28,7 @@ class project:
video_format=".mp4",
table_format=".h5",
path=".",
exp_conditions=False,
exp_conditions=None,
subset_condition=None,
arena="circular",
smooth_alpha=0.1,
......@@ -35,7 +36,7 @@ class project:
distances="All",
ego=False,
angles=True,
connectivity=None,
model="mouse_topview",
):
self.path = path
......@@ -57,9 +58,11 @@ class project:
self.distances = distances
self.ego = ego
self.angles = angles
self.connectivity = connectivity
self.scales = self.get_scale
model_dict = {"mouse_topview": connect_mouse_topview()}
self.connectivity = model_dict[model]
def __str__(self):
if self.exp_conditions:
return "DLC analysis of {} videos across {} conditions".format(
......@@ -79,15 +82,18 @@ class project:
if verbose:
print("Loading trajectories...")
tab_dict = {}
if self.table_format == ".h5":
table_dict = {
tab_dict = {
re.findall("(.*?)_", tab)[0]: pd.read_hdf(
os.path.join(self.table_path, tab), dtype=float
)
for tab in self.tables
}
elif self.table_format == ".csv":
table_dict = {}
for tab in self.tables:
head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2)
data = pd.read_csv(
......@@ -104,18 +110,18 @@ class project:
],
names=["scorer", "bodyparts", "coords"],
)
table_dict[re.findall("(.*?)_", tab)[0]] = data
tab_dict[re.findall("(.*?)_", tab)[0]] = data
lik_dict = defaultdict()
for key, value in table_dict.items():
for key, value in tab_dict.items():
x = value.xs("x", level="coords", axis=1, drop_level=False)
y = value.xs("y", level="coords", axis=1, drop_level=False)
lik: pd.DataFrame = value.xs(
"likelihood", level="coords", axis=1, drop_level=True
)
table_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1)
lik_dict[key] = lik
if self.smooth_alpha:
......@@ -123,19 +129,19 @@ class project:
if verbose:
print("Smoothing trajectories...")
for key, tab in table_dict.items():
for key, tab in tab_dict.items():
cols = tab.columns
smooth = pd.DataFrame(
smooth_mult_trajectory(np.array(tab), alpha=self.smooth_alpha)
)
smooth.columns = cols
table_dict[key] = smooth
tab_dict[key] = smooth
for key, tab in table_dict.items():
table_dict[key] = tab[tab.columns.levels[0][0]]
for key, tab in tab_dict.items():
tab_dict[key] = tab[tab.columns.levels[0][0]]
if self.subset_condition:
for key, value in table_dict.items():
for key, value in tab_dict.items():
lablist = [
b
for b in value.columns.levels[0]
......@@ -152,9 +158,9 @@ class project:
tab.columns = tabcols
table_dict[key] = tab
tab_dict[key] = tab
return table_dict, lik_dict
return tab_dict, lik_dict
@property
def get_scale(self):
......@@ -182,7 +188,7 @@ class project:
return np.array(scales)
def get_distances(self, table_dict, verbose):
def get_distances(self, tab_dict, verbose=False):
"""Computes the distances between all selected bodyparts over time.
If ego is provided, it only returns distances to a specified bodypart"""
......@@ -191,17 +197,17 @@ class project:
nodes = self.distances
if nodes == "All":
nodes = table_dict[list(table_dict.keys())[0]].columns.levels[0]
nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0]
assert [
i in list(table_dict.values())[0].columns.levels[0] for i in nodes
i in list(tab_dict.values())[0].columns.levels[0] for i in nodes
], "Nodes should correspond to existent bodyparts"
scales = self.scales[:, 2:]
distance_dict = {
key: bpart_distance(tab, scales[i, 1], scales[i, 0],)
for i, (key, tab) in enumerate(table_dict.items())
for i, (key, tab) in enumerate(tab_dict.items())
}
for key in distance_dict.keys():
......@@ -217,7 +223,7 @@ class project:
return distance_dict
def get_angles(self, table_dict, verbose):
def get_angles(self, tab_dict, verbose):
"""
Computes all the angles between adjacent bodypart trios per video and per frame in the data.
......@@ -233,12 +239,11 @@ class project:
if verbose:
print("Computing angles...")
bp_net = nx.Graph(self.connectivity)
cliques = nx.enumerate_all_cliques(bp_net)
cliques = nx.enumerate_all_cliques(self.connectivity)
cliques = [i for i in cliques if len(i) == 3]
angle_dict = {}
for key, tab in table_dict.items():
for key, tab in tab_dict.items():
dats = []
for clique in cliques:
......
......@@ -12,10 +12,57 @@ import deepof.preprocess
import pytest
@given(table_type=st.integers(min_value=0, max_value=2))
def test_project_init(table_type):
@settings(deadline=None)
@given(
table_type=st.integers(min_value=0, max_value=2),
arena_type=st.integers(min_value=0, max_value=1),
)
def test_project_init(table_type, arena_type):
table_type = [".h5", ".csv", ".foo"][table_type]
arena_type = ["circular", "foo"][arena_type]
if arena_type == "foo":
with pytest.raises(NotImplementedError):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena=arena_type,
arena_dims=[380],
angles=False,
video_format=".mp4",
table_format=table_type,
)
else:
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena=arena_type,
arena_dims=[380],
angles=False,
video_format=".mp4",
table_format=table_type,
)
if table_type != ".foo" and arena_type != "foo":
assert type(prun) == deepof.preprocess.project
assert type(prun.load_tables(verbose=True)) == tuple
assert type(prun.get_scale) == np.ndarray
print(prun)
elif table_type == ".foo" and arena_type != "foo":
with pytest.raises(NotImplementedError):
prun.load_tables(verbose=True)
@settings(deadline=None)
@given(
nodes=st.integers(min_value=0, max_value=1),
ego=st.integers(min_value=0, max_value=2),
)
def test_get_distances(nodes, ego):
nodes = ["All", ["Center", "Nose", "Tail_base"]][nodes]
ego = [False, "Center", "Nose"][ego]
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
......@@ -23,13 +70,59 @@ def test_project_init(table_type):
arena_dims=[380],
angles=False,
video_format=".mp4",
table_format=table_type,
table_format=".h5",
distances=nodes,
ego=ego,
)
if table_type != ".foo":
assert type(prun) == deepof.preprocess.project
assert type(prun.load_tables(verbose=True)) == tuple
print(prun)
else:
with pytest.raises(NotImplementedError):
prun.load_tables(verbose=True)
prun = prun.get_distances(prun.load_tables()[0], verbose=True)
assert type(prun) == dict
@settings(deadline=None)
@given(
nodes=st.integers(min_value=0, max_value=1),
ego=st.integers(min_value=0, max_value=2),
)
def test_get_angles(nodes, ego):
nodes = ["All", ["Center", "Nose", "Tail_base"]][nodes]
ego = [False, "Center", "Nose"][ego]
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
video_format=".mp4",
table_format=".h5",
distances=nodes,
ego=ego,
)
prun = prun.get_angles(prun.load_tables()[0], verbose=True)
assert type(prun) == dict
@settings(deadline=None)
@given(
nodes=st.integers(min_value=0, max_value=1),
ego=st.integers(min_value=0, max_value=2),
)
def test_run(nodes, ego):
nodes = ["All", ["Center", "Nose", "Tail_base"]][nodes]
ego = [False, "Center", "Nose"][ego]
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
video_format=".mp4",
table_format=".h5",
distances=nodes,
ego=ego,
).run(verbose=True)
assert type(prun) == deepof.preprocess.coordinates
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment