diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 2a28c2b2398ab62acb5af0c6f7240390526496c8..ed7c3c14e09984f82b3680b3a9dd326f17a2f05d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -8,8 +8,8 @@ test: - pip install coverage - pip install -r ./deepof/requirements.txt - pip install -e deepof/ - - coverage run -m pytest - - coverage report -m + - coverage run --source deepof -m pytest + - coverage report -m --omit deepof/setup.py - coverage xml -o deepof_cov.xml artifacts: reports: diff --git a/deepof/model_utils.py b/deepof/model_utils.py index 8c64c47ae96a4589afd1b56bb43ddd351e6ed269..49bea9a950b23b5092d36f573063902359851e38 100644 --- a/deepof/model_utils.py +++ b/deepof/model_utils.py @@ -4,12 +4,49 @@ from itertools import combinations from tensorflow.keras import backend as K from tensorflow.keras.constraints import Constraint from tensorflow.keras.layers import Layer +import networkx as nx import tensorflow as tf import tensorflow_probability as tfp tfd = tfp.distributions tfpl = tfp.layers + +# Connectivity for DLC models +def connect_mouse_topview(animal_id=None) -> nx.Graph: + """Creates a nx.Graph object with the connectivity of the bodyparts in the + DLC topview model for a single mouse. Used later for angle computing, among others + + Parameters: + - animal_id (str): if more than one animal is tagged, + specify the animal identyfier as a string + + Returns: + - connectivity (nx.Graph)""" + + connectivity = { + "Nose": ["Left_ear", "Right_ear", "Spine_1"], + "Left_ear": ["Right_ear", "Spine_1"], + "Right_ear": ["Spine_1"], + "Spine_1": ["Center", "Left_fhip", "Right_fhip"], + "Center": ["Left_fhip", "Right_fhip", "Spine_2", "Left_bhip", "Right_bhip"], + "Spine_2": ["Left_bhip", "Right_bhip", "Tail_base"], + "Tail_base": ["Tail_1", "Left_bhip", "Right_bhip"], + "Tail_1": ["Tail_2"], + "Tail_2": ["Tail_tip"], + } + + connectivity = nx.Graph(connectivity) + + if animal_id: + mapping = { + node: "{}_{}".format(animal_id, node) for node in connectivity.nodes() + } + nx.relabel_nodes(connectivity, mapping, copy=False) + + return connectivity + + # Helper functions @tf.function def far_away_uniform_initialiser(shape, minval=0, maxval=15, iters=100000): diff --git a/deepof/preprocess.py b/deepof/preprocess.py index 3d8c4c3de9cdb615fbc8c79124237f182858af73..51e463ec2825b0a6e358fd1921e28b1bb7d0f37c 100644 --- a/deepof/preprocess.py +++ b/deepof/preprocess.py @@ -13,6 +13,7 @@ import networkx as nx from deepof.utils import * from deepof.visuals import * +from deepof.model_utils import connect_mouse_topview class project: @@ -27,7 +28,7 @@ class project: video_format=".mp4", table_format=".h5", path=".", - exp_conditions=False, + exp_conditions=None, subset_condition=None, arena="circular", smooth_alpha=0.1, @@ -35,7 +36,7 @@ class project: distances="All", ego=False, angles=True, - connectivity=None, + model="mouse_topview", ): self.path = path @@ -57,9 +58,11 @@ class project: self.distances = distances self.ego = ego self.angles = angles - self.connectivity = connectivity self.scales = self.get_scale + model_dict = {"mouse_topview": connect_mouse_topview()} + self.connectivity = model_dict[model] + def __str__(self): if self.exp_conditions: return "DLC analysis of {} videos across {} conditions".format( @@ -79,15 +82,18 @@ class project: if verbose: print("Loading trajectories...") + tab_dict = {} + if self.table_format == ".h5": - table_dict = { + + tab_dict = { re.findall("(.*?)_", tab)[0]: pd.read_hdf( os.path.join(self.table_path, tab), dtype=float ) for tab in self.tables } elif self.table_format == ".csv": - table_dict = {} + for tab in self.tables: head = pd.read_csv(os.path.join(self.table_path, tab), nrows=2) data = pd.read_csv( @@ -104,18 +110,18 @@ class project: ], names=["scorer", "bodyparts", "coords"], ) - table_dict[re.findall("(.*?)_", tab)[0]] = data + tab_dict[re.findall("(.*?)_", tab)[0]] = data lik_dict = defaultdict() - for key, value in table_dict.items(): + for key, value in tab_dict.items(): x = value.xs("x", level="coords", axis=1, drop_level=False) y = value.xs("y", level="coords", axis=1, drop_level=False) lik: pd.DataFrame = value.xs( "likelihood", level="coords", axis=1, drop_level=True ) - table_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1) + tab_dict[key] = pd.concat([x, y], axis=1).sort_index(axis=1) lik_dict[key] = lik if self.smooth_alpha: @@ -123,19 +129,19 @@ class project: if verbose: print("Smoothing trajectories...") - for key, tab in table_dict.items(): + for key, tab in tab_dict.items(): cols = tab.columns smooth = pd.DataFrame( smooth_mult_trajectory(np.array(tab), alpha=self.smooth_alpha) ) smooth.columns = cols - table_dict[key] = smooth + tab_dict[key] = smooth - for key, tab in table_dict.items(): - table_dict[key] = tab[tab.columns.levels[0][0]] + for key, tab in tab_dict.items(): + tab_dict[key] = tab[tab.columns.levels[0][0]] if self.subset_condition: - for key, value in table_dict.items(): + for key, value in tab_dict.items(): lablist = [ b for b in value.columns.levels[0] @@ -152,9 +158,9 @@ class project: tab.columns = tabcols - table_dict[key] = tab + tab_dict[key] = tab - return table_dict, lik_dict + return tab_dict, lik_dict @property def get_scale(self): @@ -182,7 +188,7 @@ class project: return np.array(scales) - def get_distances(self, table_dict, verbose): + def get_distances(self, tab_dict, verbose=False): """Computes the distances between all selected bodyparts over time. If ego is provided, it only returns distances to a specified bodypart""" @@ -191,17 +197,17 @@ class project: nodes = self.distances if nodes == "All": - nodes = table_dict[list(table_dict.keys())[0]].columns.levels[0] + nodes = tab_dict[list(tab_dict.keys())[0]].columns.levels[0] assert [ - i in list(table_dict.values())[0].columns.levels[0] for i in nodes + i in list(tab_dict.values())[0].columns.levels[0] for i in nodes ], "Nodes should correspond to existent bodyparts" scales = self.scales[:, 2:] distance_dict = { key: bpart_distance(tab, scales[i, 1], scales[i, 0],) - for i, (key, tab) in enumerate(table_dict.items()) + for i, (key, tab) in enumerate(tab_dict.items()) } for key in distance_dict.keys(): @@ -217,7 +223,7 @@ class project: return distance_dict - def get_angles(self, table_dict, verbose): + def get_angles(self, tab_dict, verbose): """ Computes all the angles between adjacent bodypart trios per video and per frame in the data. @@ -233,12 +239,11 @@ class project: if verbose: print("Computing angles...") - bp_net = nx.Graph(self.connectivity) - cliques = nx.enumerate_all_cliques(bp_net) + cliques = nx.enumerate_all_cliques(self.connectivity) cliques = [i for i in cliques if len(i) == 3] angle_dict = {} - for key, tab in table_dict.items(): + for key, tab in tab_dict.items(): dats = [] for clique in cliques: diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 2fdb0d26a5668b748c0e4ff224543c12023db260..5c576646d8e422270e66d753e812787d75811392 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -12,10 +12,57 @@ import deepof.preprocess import pytest -@given(table_type=st.integers(min_value=0, max_value=2)) -def test_project_init(table_type): +@settings(deadline=None) +@given( + table_type=st.integers(min_value=0, max_value=2), + arena_type=st.integers(min_value=0, max_value=1), +) +def test_project_init(table_type, arena_type): table_type = [".h5", ".csv", ".foo"][table_type] + arena_type = ["circular", "foo"][arena_type] + + if arena_type == "foo": + with pytest.raises(NotImplementedError): + prun = deepof.preprocess.project( + path=os.path.join(".", "tests", "test_examples"), + arena=arena_type, + arena_dims=[380], + angles=False, + video_format=".mp4", + table_format=table_type, + ) + else: + prun = deepof.preprocess.project( + path=os.path.join(".", "tests", "test_examples"), + arena=arena_type, + arena_dims=[380], + angles=False, + video_format=".mp4", + table_format=table_type, + ) + + if table_type != ".foo" and arena_type != "foo": + + assert type(prun) == deepof.preprocess.project + assert type(prun.load_tables(verbose=True)) == tuple + assert type(prun.get_scale) == np.ndarray + print(prun) + + elif table_type == ".foo" and arena_type != "foo": + with pytest.raises(NotImplementedError): + prun.load_tables(verbose=True) + + +@settings(deadline=None) +@given( + nodes=st.integers(min_value=0, max_value=1), + ego=st.integers(min_value=0, max_value=2), +) +def test_get_distances(nodes, ego): + + nodes = ["All", ["Center", "Nose", "Tail_base"]][nodes] + ego = [False, "Center", "Nose"][ego] prun = deepof.preprocess.project( path=os.path.join(".", "tests", "test_examples"), @@ -23,13 +70,59 @@ def test_project_init(table_type): arena_dims=[380], angles=False, video_format=".mp4", - table_format=table_type, + table_format=".h5", + distances=nodes, + ego=ego, ) - if table_type != ".foo": - assert type(prun) == deepof.preprocess.project - assert type(prun.load_tables(verbose=True)) == tuple - print(prun) - else: - with pytest.raises(NotImplementedError): - prun.load_tables(verbose=True) + prun = prun.get_distances(prun.load_tables()[0], verbose=True) + + assert type(prun) == dict + + +@settings(deadline=None) +@given( + nodes=st.integers(min_value=0, max_value=1), + ego=st.integers(min_value=0, max_value=2), +) +def test_get_angles(nodes, ego): + + nodes = ["All", ["Center", "Nose", "Tail_base"]][nodes] + ego = [False, "Center", "Nose"][ego] + + prun = deepof.preprocess.project( + path=os.path.join(".", "tests", "test_examples"), + arena="circular", + arena_dims=[380], + video_format=".mp4", + table_format=".h5", + distances=nodes, + ego=ego, + ) + + prun = prun.get_angles(prun.load_tables()[0], verbose=True) + + assert type(prun) == dict + + +@settings(deadline=None) +@given( + nodes=st.integers(min_value=0, max_value=1), + ego=st.integers(min_value=0, max_value=2), +) +def test_run(nodes, ego): + + nodes = ["All", ["Center", "Nose", "Tail_base"]][nodes] + ego = [False, "Center", "Nose"][ego] + + prun = deepof.preprocess.project( + path=os.path.join(".", "tests", "test_examples"), + arena="circular", + arena_dims=[380], + video_format=".mp4", + table_format=".h5", + distances=nodes, + ego=ego, + ).run(verbose=True) + + assert type(prun) == deepof.preprocess.coordinates