Commit 6b10609e authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored single_behaviour_analysis and added tests

parent 7d43c995
......@@ -8,12 +8,12 @@ import numpy as np
import os
import pandas as pd
import regex as re
import scipy
import seaborn as sns
from copy import deepcopy
from itertools import cycle, combinations, product
from joblib import Parallel, delayed
from scipy import spatial
from scipy import stats
from sklearn import mixture
from tqdm import tqdm_notebook as tqdm
......@@ -586,16 +586,29 @@ def following_path(
def single_behaviour_analysis(
behaviour_name,
treatment_dict,
behavioural_dict,
plot=False,
stats=False,
save=False,
ylim=False,
):
behaviour_name: str,
treatment_dict: dict,
behavioural_dict: dict,
plot: int = 0,
stat_tests: bool = True,
save: str = None,
ylim: float = None,
) -> list:
"""Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
with the actual taggings, outputs a box plot and a series of significance tests amongst the groups"""
with the actual tags, outputs a box plot and a series of significance tests amongst the groups
Parameters:
- behaviour_name (str): name of the behavioural trait to analyse
- treatment_dict (dict): dictionary containing video names as keys and experimental conditions as values
- behavioural_dict (dict): tagged dictionary containing video names as keys and annotations as values
- plot (int): Silent if 0; otherwise, indicates the dpi of the figure to plot
- stat_tests (bool): performs FDR corrected Mann-U non-parametric tests among all groups if True
- save (str): Saves the produced figure to the specified file
- ylim (float): y-limit for the boxplot. Ignored if plot == 0
Returns:
- beh_dict (dict): dictionary containing experimental conditions as keys and video names as values
- stat_dict (dict): dictionary containing condition pairs as keys and stat results as values"""
beh_dict = {condition: [] for condition in treatment_dict.keys()}
......@@ -606,34 +619,40 @@ def single_behaviour_analysis(
/ len(behavioural_dict[ind][behaviour_name])
)
if plot:
sns.boxplot(list(beh_dict.keys()), list(beh_dict.values()), orient="vertical")
return_list = [beh_dict]
plt.title("{} across groups".format(behaviour_name))
plt.ylabel("Proportion of frames")
if plot > 0:
if ylim != False:
plt.ylim(*ylim)
fig, ax = plt.subplots(dpi=plot)
plt.tight_layout()
plt.savefig("Exploration_heatmaps.pdf")
sns.boxplot(
list(beh_dict.keys()), list(beh_dict.values()), orient="vertical", ax=ax
)
ax.set_title("{} across groups".format(behaviour_name))
ax.set_ylabel("Proportion of frames")
if ylim is not None:
ax.set_ylim(ylim)
if save != False:
if save is not None:
plt.savefig(save)
plt.show()
return_list.append(ax)
if stats:
if stat_tests:
stat_dict = {}
for i in combinations(treatment_dict.keys(), 2):
print(i)
print(scipy.stats.mannwhitneyu(beh_dict[i[0]], beh_dict[i[1]]))
stat_dict[i] = stats.mannwhitneyu(beh_dict[i[0]], beh_dict[i[1]])
return_list.append(stat_dict)
return beh_dict
return return_list
##### MAIN BEHAVIOUR TAGGING FUNCTION #####
# MAIN BEHAVIOUR TAGGING FUNCTION #
def Tag_video(
def tag_video(
Tracks,
Videos,
Track_dict,
......@@ -1140,3 +1159,7 @@ def model_comparison_plot(
plt.savefig(save)
plt.show()
# TODO:
# - Add sequence plot to single_behaviour_analysis (show how the condition varies across a specified time window)
# @author lucasmiranda42
from hypothesis import given
from hypothesis import HealthCheck
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays
......@@ -450,7 +451,6 @@ def test_rolling_speed(dframe, sampler):
order1 = sampler.draw(st.integers(min_value=1, max_value=3))
order2 = sampler.draw(st.integers(min_value=order1, max_value=3))
window2 = sampler.draw(st.integers(min_value=10, max_value=25))
idx = pd.MultiIndex.from_product(
[["bpart1", "bpart2"], ["X", "y"]], names=["bodyparts", "coords"],
......@@ -459,7 +459,6 @@ def test_rolling_speed(dframe, sampler):
speeds1 = rolling_speed(dframe, 5, 10, order1)
speeds2 = rolling_speed(dframe, 5, 10, order2)
speeds3 = rolling_speed(dframe, window2, 10, order1)
assert speeds1.shape[0] == dframe.shape[0]
assert speeds1.shape[1] == dframe.shape[1] // 2
......@@ -527,7 +526,7 @@ def test_huddle(pos_dframe, tol_forward, tol_spine):
distance_dframe=data_frames(
index=range_indexes(min_size=20, max_size=20),
columns=columns(
["d1", "d2", "d3", "d4",],
["d1", "d2", "d3", "d4"],
dtype=float,
elements=st.floats(min_value=-20, max_value=20),
),
......@@ -535,7 +534,7 @@ def test_huddle(pos_dframe, tol_forward, tol_spine):
position_dframe=data_frames(
index=range_indexes(min_size=20, max_size=20),
columns=columns(
["X1", "y1", "X2", "y2", "X3", "y3", "X4", "y4",],
["X1", "y1", "X2", "y2", "X3", "y3", "X4", "y4"],
dtype=float,
elements=st.floats(min_value=-20, max_value=20),
),
......@@ -553,7 +552,7 @@ def test_following_path(distance_dframe, position_dframe, frames, tol):
]
pos_idx = pd.MultiIndex.from_product(
[bparts, ["X", "y"],], names=["bodyparts", "coords"],
[bparts, ["X", "y"]], names=["bodyparts", "coords"],
)
position_dframe.columns = pos_idx
......@@ -573,3 +572,52 @@ def test_following_path(distance_dframe, position_dframe, frames, tol):
assert len(follow) == distance_dframe.shape[0]
assert np.sum(follow) <= position_dframe.shape[0]
assert np.sum(follow) <= distance_dframe.shape[0]
@settings(
    deadline=None, suppress_health_check=[HealthCheck.too_slow],
)
@given(sampler=st.data())
def test_single_behaviour_analysis(sampler):
    """Property test for single_behaviour_analysis run in silent mode (plot=0).

    Draws random behaviour names, treatment names and a dictionary of boolean
    tag tables (one 50-row data frame per video), assigns each video to a
    random treatment, and checks the structure of the returned list:

    - one element when stat_tests is False, two when it is True;
    - the first element (per-condition values) is a dict;
    - when requested, the second element (pairwise test results) is a dict.
    """
    behaviours = sampler.draw(
        st.lists(min_size=2, elements=st.text(min_size=5), unique=True)
    )
    treatments = sampler.draw(
        st.lists(min_size=2, max_size=4, elements=st.text(min_size=5), unique=True)
    )
    behavioural_dict = sampler.draw(
        st.dictionaries(
            min_size=2,
            keys=st.text(min_size=5),
            values=data_frames(
                index=range_indexes(min_size=50, max_size=50),
                columns=columns(behaviours, dtype=bool),
            ).map(
                # Overwrite the drawn frame with random 0/1 tags of the same shape
                lambda x: 0 * x + np.array(np.random.randint(0, 2, x.shape), dtype=bool)
            ),
        )
    )

    # Assign every video to one treatment, then invert the mapping so that
    # each treatment that was actually used lists its videos
    ind_dict = {vid: np.random.choice(treatments) for vid in behavioural_dict.keys()}
    treatment_dict = {treat: [] for treat in set(ind_dict.values())}
    for vid, treat in ind_dict.items():
        treatment_dict[treat].append(vid)

    ylim = sampler.draw(st.floats(min_value=0, max_value=10))
    stat_tests = sampler.draw(st.booleans())

    out = single_behaviour_analysis(
        behaviours[0],
        treatment_dict,
        behavioural_dict,
        plot=0,
        stat_tests=stat_tests,
        save=None,
        ylim=ylim,
    )

    # With plot=0 no axes object is appended, so the list holds the
    # per-condition dict plus, optionally, the stats dict
    assert len(out) == (2 if stat_tests else 1)
    assert isinstance(out[0], dict)
    if stat_tests:
        # Check the stats dictionary itself, not the first element again
        assert isinstance(out[1], dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment