Commit 6f74d1f7 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added first tests for rule_based_tagging

parent 2158e4b7
......@@ -163,7 +163,7 @@ class project:
vid_index,
path=self.video_path,
arena_type=self.arena,
)
)[0]
* 2
)
+ self.arena_dims
......@@ -376,7 +376,7 @@ class coordinates:
cols = tab.columns.levels[0]
except AttributeError:
cols = tab.columns
vel = rolling_speed(tab, typ="coords", order=order + 1)
vel = rolling_speed(tab, deriv=order + 1)
vel.columns = cols
tabs[key] = vel
......@@ -422,7 +422,7 @@ class coordinates:
cols = tab.columns.levels[0]
except AttributeError:
cols = tab.columns
vel = rolling_speed(tab, typ="dists", order=order + 1)
vel = rolling_speed(tab, deriv=order + 1)
vel.columns = cols
tabs[key] = vel
......@@ -453,7 +453,7 @@ class coordinates:
cols = tab.columns.levels[0]
except AttributeError:
cols = tab.columns
vel = rolling_speed(tab, typ="dists", order=order + 1)
vel = rolling_speed(tab, deriv=order + 1)
vel.columns = cols
tabs[key] = vel
......
......@@ -21,7 +21,7 @@ from typing import Tuple, Any, List, Union, Dict, NewType
# DEFINE CUSTOM ANNOTATED TYPES #
TableDict = NewType("TableDict", Any)
Coordinates = NewType("Coordinates", Any)
# QUALITY CONTROL AND PREPROCESSING #
......@@ -408,7 +408,7 @@ def recognize_arena(
# Detect arena and extract positions
arena = circular_arena_recognition(frame)[0]
if h is not None and w is not None:
if h is None and w is None:
h, w = frame.shape[0], frame.shape[1]
fnum += 1
......@@ -551,7 +551,7 @@ def huddle(
< tol_forward
)
spine = ["Spine1", "Center", "Spine2", "Tail_base"]
spine = ["Spine_1", "Center", "Spine_2", "Tail_base"]
spine_dists = []
for comb in range(2):
spine_dists.append(
......@@ -871,7 +871,7 @@ def cluster_transition_matrix(
def rule_based_tagging(
tracks: List,
videos: List,
table_dict: TableDict,
coordinates: Coordinates,
vid_index: int,
arena_abs: int,
animal_ids: List = None,
......@@ -884,18 +884,13 @@ def rule_based_tagging(
path: str = os.path.join("./"),
arena_type: str = "circular",
classifiers: Dict = None,
) -> Tuple[pd.DataFrame, Any]:
"""Outputs a dataframe with the motives registered per frame. If mp4==True, outputs a video in mp4 format"""
# noinspection PyProtectedMember
assert table_dict._type == "merged", (
"Table_dict must be of merged type, "
"and contain at least position, speed and distance information"
)
) -> pd.DataFrame:
"""Outputs a dataframe with the motives registered per frame."""
vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
dframe = table_dict[vid_name]
distances = coordinates.get_coords()[vid_name]
speeds = coordinates.get_coords(speed=1)[vid_name]
arena, h, w = recognize_arena(videos, vid_index, path, recog_limit, arena_type)
# Dictionary with motives per frame
......@@ -913,15 +908,15 @@ def rule_based_tagging(
behavioural_tags.append(_id + behaviour)
else:
behavioural_tags.append(["huddle", "climbing", "speed"])
behavioural_tags += ["huddle", "climbing", "speed"]
tag_dict = {tag: np.zeros(dframe.shape[0]) for tag in behavioural_tags}
tag_dict = {tag: np.zeros(distances.shape[0]) for tag in behavioural_tags}
if animal_ids:
# Define behaviours that can be computed on the fly from the distance matrix
tag_dict["nose2nose"] = smooth_boolean_array(
close_single_contact(
dframe,
distances,
animal_ids[0] + "_Nose",
animal_ids[1] + "_Nose",
15.0,
......@@ -931,7 +926,7 @@ def rule_based_tagging(
)
tag_dict[animal_ids[0] + "_nose2tail"] = smooth_boolean_array(
close_single_contact(
dframe,
distances,
animal_ids[0] + "_Nose",
animal_ids[1] + "_Tail_base",
15.0,
......@@ -941,7 +936,7 @@ def rule_based_tagging(
)
tag_dict[animal_ids[1] + "_nose2tail"] = smooth_boolean_array(
close_single_contact(
dframe,
distances,
animal_ids[1] + "_Nose",
animal_ids[0] + "_Tail_base",
15.0,
......@@ -951,7 +946,7 @@ def rule_based_tagging(
)
tag_dict["sidebyside"] = smooth_boolean_array(
close_double_contact(
dframe,
distances,
animal_ids[0] + "_Nose",
animal_ids[0] + "_Tail_base",
animal_ids[1] + "_Nose",
......@@ -964,7 +959,7 @@ def rule_based_tagging(
)
tag_dict["sidereside"] = smooth_boolean_array(
close_double_contact(
dframe,
distances,
animal_ids[0] + "_Nose",
animal_ids[0] + "_Tail_base",
animal_ids[1] + "_Nose",
......@@ -978,8 +973,8 @@ def rule_based_tagging(
for _id in animal_ids:
tag_dict[_id + "_following"] = smooth_boolean_array(
following_path(
dframe[vid_name],
dframe,
distances[vid_name],
distances,
follower=_id,
followed=[i for i in animal_ids if i != _id][0],
frames=20,
......@@ -990,38 +985,37 @@ def rule_based_tagging(
pd.Series(
(
spatial.distance.cdist(
np.array(dframe[_id + "_Nose"]), np.array([arena[:2]])
np.array(distances[_id + "_Nose"]), np.array([arena[:2]])
)
> (w / 200 + arena[2])
).reshape(dframe.shape[0]),
index=dframe.index,
).reshape(distances.shape[0]),
index=distances.index,
)
)
tag_dict[_id + "speed"] = rolling_speed(
dframe[_id + "_Center"], window=speed_pause
)
tag_dict[_id + "_speed"] = speeds[_id + "_speed"]
else:
print(w)
tag_dict["climbwall"] = smooth_boolean_array(
pd.Series(
(
spatial.distance.cdist(
np.array(dframe["Nose"]), np.array([arena[:2]])
np.array(distances["Nose"]), np.array([arena[:2]])
)
> (w / 200 + arena[2])
).reshape(dframe.shape[0]),
index=dframe.index,
).reshape(distances.shape[0]),
index=distances.index,
)
)
tag_dict["speed"] = rolling_speed(dframe["Center"], window=speed_pause)
tag_dict["speed"] = speeds["Center"]
if "huddle" in classifiers:
if classifiers and "huddle" in classifiers:
mouse_X = {
_id: np.array(
dframe[vid_name][
distances[vid_name][
[
j
for j in dframe[vid_name].keys()
for j in distances[vid_name].keys()
if (len(j) == 2 and _id in j[0] and _id in j[1])
]
]
......@@ -1036,10 +1030,10 @@ def rule_based_tagging(
try:
for _id in animal_ids:
tag_dict[_id + "_huddle"] = smooth_boolean_array(
huddle(dframe, 25, 25, 5)
huddle(distances, 25, 25, 5)
)
except TypeError:
tag_dict["huddle"] = smooth_boolean_array(huddle(dframe, 25, 25, 5))
tag_dict["huddle"] = smooth_boolean_array(huddle(distances, 25, 25, 5))
# if any([show, save]):
# cap = cv2.VideoCapture(path + videos[vid_index])
......@@ -1201,9 +1195,9 @@ def rule_based_tagging(
# cap.release()
# cv2.destroyAllWindows()
tagdf = pd.DataFrame(tag_dict)
tag_df = pd.DataFrame(tag_dict)
return tagdf, arena
return tag_df
# TODO:
......@@ -1211,3 +1205,7 @@ def rule_based_tagging(
# - Add digging to rule_based_tagging
# - Add center to rule_based_tagging
# - Check for features requested by Joeri
# - Check speed. Avoid recomputing unnecessarily
# - Pass thresholds as parameters of the function. Provide defaults (we should tune them in the future)
# - Check if attributes I'm asking for (eg arena) are already stored in Table_dict metadata
......@@ -84,7 +84,7 @@ def model_comparison_plot(
bars.append(
spl.bar(
xpos,
m_bic[i * len(n_components_range) : (i + 1) * len(n_components_range)],
m_bic[i * len(n_components_range): (i + 1) * len(n_components_range)],
color=color,
width=0.2,
)
......
......@@ -376,7 +376,7 @@ def test_close_double_contact(pos_dframe, tol, rev):
)
pos_dframe.columns = idx
close_contact = close_double_contact(
pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, rev, 1, 1
pos_dframe, "bpart1", "bpart2", "bpart3", "bpart4", tol, 1, 1, rev
)
assert close_contact.dtype == bool
assert np.array(close_contact).shape[0] <= pos_dframe.shape[0]
......@@ -386,14 +386,21 @@ def test_close_double_contact(pos_dframe, tol, rev):
@given(indexes=st.data())
def test_recognize_arena_and_subfunctions(indexes):
path = "./tests/test_examples/Videos/"
path = os.path.join(".", "tests", "test_examples", "Videos")
videos = [i for i in os.listdir(path) if i.endswith("mp4")]
vid_index = indexes.draw(st.integers(min_value=0, max_value=len(videos) - 1))
recoglimit = indexes.draw(st.integers(min_value=1, max_value=10))
assert recognize_arena(videos, vid_index, path, recoglimit, "") == 0
assert recognize_arena(videos, vid_index, path, recoglimit, "")[0] == 0
assert len(recognize_arena(videos, vid_index, path, recoglimit, "circular")) == 3
assert len(recognize_arena(videos, vid_index, path, recoglimit, "circular")[0]) == 3
assert (
type(recognize_arena(videos, vid_index, path, recoglimit, "circular")[1]) == int
)
assert (
type(recognize_arena(videos, vid_index, path, recoglimit, "circular")[2]) == int
)
@settings(deadline=None)
......@@ -410,9 +417,9 @@ def test_climb_wall(arena, tol):
prun = (
deepof.preprocess.project(
path="./tests/test_examples",
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[arena[0]],
arena_dims=[arena[2]],
angles=False,
video_format=".mp4",
table_format=".h5",
......@@ -495,8 +502,9 @@ def test_rolling_speed(dframe, sampler):
),
tol_forward=st.floats(min_value=0.01, max_value=4.98),
tol_spine=st.floats(min_value=0.01, max_value=4.98),
tol_speed=st.floats(min_value=0.01, max_value=4.98),
)
def test_huddle(pos_dframe, tol_forward, tol_spine):
def test_huddle(pos_dframe, tol_forward, tol_spine, tol_speed):
idx = pd.MultiIndex.from_product(
[
......@@ -505,9 +513,9 @@ def test_huddle(pos_dframe, tol_forward, tol_spine):
"Right_ear",
"Left_fhip",
"Right_fhip",
"Spine1",
"Spine_1",
"Center",
"Spine2",
"Spine_2",
"Tail_base",
],
["X", "y"],
......@@ -515,7 +523,7 @@ def test_huddle(pos_dframe, tol_forward, tol_spine):
names=["bodyparts", "coords"],
)
pos_dframe.columns = idx
hudd = huddle(pos_dframe, tol_forward, tol_spine)
hudd = huddle(pos_dframe, tol_forward, tol_spine, tol_speed)
assert hudd.dtype == bool
assert np.array(hudd).shape[0] == pos_dframe.shape[0]
......@@ -734,6 +742,26 @@ def test_cluster_transition_matrix(sampler, autocorrelation, return_graph):
@settings(deadline=None)
@given()
def test_rule_based_tagging():
pass
\ No newline at end of file
@given(sampler=st.data())
def test_rule_based_tagging(sampler):
prun = deepof.preprocess.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=[380],
angles=False,
video_format=".mp4",
table_format=".h5",
).run(verbose=False)
hardcoded_tags = rule_based_tagging(
list([i + "_" for i in prun.get_coords().keys()]),
["test_video_circular_arena.mp4"],
prun,
vid_index=0,
arena_abs=380,
path=os.path.join(".", "tests", "test_examples", "Videos"),
)
assert type(hardcoded_tags) == pd.DataFrame
assert hardcoded_tags.shape[1] == 4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment