Commit 9c87263b authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored pose_utils.py

parent 62b7d7fc
Pipeline #83184 passed with stage
in 14 minutes and 18 seconds
......@@ -26,12 +26,12 @@ Coordinates = NewType("Coordinates", Any)
def close_single_contact(
pos_dframe: pd.DataFrame,
left: str,
right: str,
tol: float,
arena_abs: int,
arena_rel: int,
pos_dframe: pd.DataFrame,
left: str,
right: str,
tol: float,
arena_abs: int,
arena_rel: int,
) -> np.array:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
......@@ -49,22 +49,22 @@ def close_single_contact(
is less than tol, False otherwise"""
close_contact = (
np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
) / arena_rel < tol
np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
) / arena_rel < tol
return close_contact
def close_double_contact(
pos_dframe: pd.DataFrame,
left1: str,
left2: str,
right1: str,
right2: str,
tol: float,
arena_abs: int,
arena_rel: int,
rev: bool = False,
pos_dframe: pd.DataFrame,
left1: str,
left2: str,
right1: str,
right2: str,
tol: float,
arena_abs: int,
arena_rel: int,
rev: bool = False,
) -> np.array:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
......@@ -86,31 +86,31 @@ def close_double_contact(
if rev:
double_contact = (
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
)
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
)
else:
double_contact = (
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
)
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
)
return double_contact
def climb_wall(
arena_type: str, arena: np.array, pos_dict: pd.DataFrame, tol: float, nose: str
arena_type: str, arena: np.array, pos_dict: pd.DataFrame, tol: float, nose: str
) -> np.array:
"""Returns True if the specified mouse is climbing the wall
......@@ -139,12 +139,12 @@ def climb_wall(
def huddle(
pos_dframe: pd.DataFrame,
speed_dframe: pd.DataFrame,
tol_forward: float,
tol_spine: float,
tol_speed: float,
animal_id: str = "",
pos_dframe: pd.DataFrame,
speed_dframe: pd.DataFrame,
tol_forward: float,
tol_spine: float,
tol_speed: float,
animal_id: str = "",
) -> np.array:
"""Returns true when the mouse is huddling using simple rules. (!!!) Designed to
work with deepof's default DLC mice models; not guaranteed to work otherwise.
......@@ -166,18 +166,18 @@ def huddle(
animal_id += "_"
forward = (
np.linalg.norm(
pos_dframe[animal_id + "Left_ear"] - pos_dframe[animal_id + "Left_fhip"],
axis=1,
)
< tol_forward
) & (
np.linalg.norm(
pos_dframe[animal_id + "Right_ear"] - pos_dframe[animal_id + "Right_fhip"],
axis=1,
)
< tol_forward
)
np.linalg.norm(
pos_dframe[animal_id + "Left_ear"] - pos_dframe[animal_id + "Left_fhip"],
axis=1,
)
< tol_forward
) & (
np.linalg.norm(
pos_dframe[animal_id + "Right_ear"] - pos_dframe[animal_id + "Right_fhip"],
axis=1,
)
< tol_forward
)
spine = [
animal_id + "Spine_1",
......@@ -200,12 +200,12 @@ def huddle(
def following_path(
distance_dframe: pd.DataFrame,
position_dframe: pd.DataFrame,
follower: str,
followed: str,
frames: int = 20,
tol: float = 0,
distance_dframe: pd.DataFrame,
position_dframe: pd.DataFrame,
follower: str,
followed: str,
frames: int = 20,
tol: float = 0,
) -> np.array:
"""For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
followed has walked over the last specified number of frames
......@@ -236,15 +236,15 @@ def following_path(
# Check that the animals are oriented follower's nose -> followed's tail
right_orient1 = (
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[
tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
]
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[
tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
]
)
right_orient2 = (
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
)
follow = np.all(
......@@ -255,13 +255,13 @@ def following_path(
def single_behaviour_analysis(
behaviour_name: str,
treatment_dict: dict,
behavioural_dict: dict,
plot: int = 0,
stat_tests: bool = True,
save: str = None,
ylim: float = None,
behaviour_name: str,
treatment_dict: dict,
behavioural_dict: dict,
plot: int = 0,
stat_tests: bool = True,
save: str = None,
ylim: float = None,
) -> list:
"""Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
with the actual tags, outputs a box plot and a series of significance tests amongst the groups
......@@ -314,9 +314,9 @@ def single_behaviour_analysis(
for i in combinations(treatment_dict.keys(), 2):
# Solves issue with automatically generated examples
if (
beh_dict[i[0]] == beh_dict[i[1]]
or np.var(beh_dict[i[0]]) == 0
or np.var(beh_dict[i[1]]) == 0
beh_dict[i[0]] == beh_dict[i[1]]
or np.var(beh_dict[i[0]]) == 0
or np.var(beh_dict[i[1]]) == 0
):
stat_dict[i] = "Identical sources. Couldn't run"
else:
......@@ -329,7 +329,7 @@ def single_behaviour_analysis(
def max_behaviour(
behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.array:
"""Returns the most frequent behaviour in a window of window_size frames
......@@ -374,7 +374,7 @@ def get_hparameters(hparams: dict = {}) -> dict:
"follow_tol": 20,
"huddle_forward": 15,
"huddle_spine": 10,
"huddle_speed": 1,
"huddle_speed": 0.1,
"fps": 24,
}
......@@ -412,13 +412,13 @@ def frame_corners(w, h, corners: dict = {}):
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
tracks: List,
videos: List,
coordinates: Coordinates,
vid_index: int,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
tracks: List,
videos: List,
coordinates: Coordinates,
vid_index: int,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
) -> pd.DataFrame:
"""Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
video displaying the information in real time
......@@ -522,10 +522,10 @@ def rule_based_tagging(
tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
pd.Series(
(
spatial.distance.cdist(
np.array(coords[_id + undercond + "Nose"]), np.zeros([1, 2])
)
> (w / 200 + arena[2])
spatial.distance.cdist(
np.array(coords[_id + undercond + "Nose"]), np.zeros([1, 2])
)
> (w / 200 + arena[2])
).reshape(coords.shape[0]),
index=coords.index,
).astype(bool)
......@@ -548,15 +548,15 @@ def rule_based_tagging(
# noinspection PyProtectedMember,PyDefaultArgument
def rule_based_video(
coordinates: Coordinates,
tracks: List,
videos: List,
vid_index: int,
tag_dict: pd.DataFrame,
frame_limit: int = np.inf,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
coordinates: Coordinates,
tracks: List,
videos: List,
vid_index: int,
tag_dict: pd.DataFrame,
frame_limit: int = np.inf,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
) -> True:
"""Renders a version of the input video with all rule-based taggings in place.
......@@ -623,8 +623,8 @@ def rule_based_video(
# Capture speeds
try:
if (
list(frame_speeds.values())[0] == -np.inf
or fnum % hparams["speed_pause"] == 0
list(frame_speeds.values())[0] == -np.inf
or fnum % hparams["speed_pause"] == 0
):
for _id in animal_ids:
frame_speeds[_id] = speeds[_id + undercond + "Center"][fnum]
......@@ -645,24 +645,32 @@ def rule_based_video(
else:
return corners["downright"]
def conditional_col():
def conditional_col(cond=None):
"""Returns a colour depending on a condition"""
if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]:
if cond is None:
cond = frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
if cond:
return 150, 150, 255
else:
return 150, 255, 150
zipped_pos = zip(
animal_ids,
[corners["downleft"], corners["downright"]],
[corners["upleft"], corners["upright"]],
)
if len(animal_ids) > 1:
if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
write_on_frame("Nose-Nose", conditional_pos())
if (
tag_dict[animal_ids[0] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
tag_dict[animal_ids[0] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
):
write_on_frame("Nose-Tail", corners["downleft"])
if (
tag_dict[animal_ids[1] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
tag_dict[animal_ids[1] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
):
write_on_frame("Nose-Tail", corners["downright"])
if tag_dict["sidebyside"][fnum]:
......@@ -673,50 +681,35 @@ def rule_based_video(
write_on_frame(
"Side-Rside", conditional_pos(),
)
for _id, down_pos, up_pos in zip(
animal_ids,
[corners["downleft"], corners["downright"]],
[corners["upleft"], corners["upright"]],
):
if tag_dict[_id + "_climbing"][fnum]:
write_on_frame("Climbing", down_pos)
for _id, down_pos, up_pos in zipped_pos:
if (
tag_dict[_id + "_huddle"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
):
write_on_frame("Huddling", down_pos)
if (
tag_dict[_id + "_following"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
tag_dict[_id + "_following"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
):
write_on_frame(
"*f",
(int(w * 0.3 / 10), int(h / 10)),
conditional_col(),
"*f", (int(w * 0.3 / 10), int(h / 10)), conditional_col(),
)
write_on_frame(
_id + ": " + str(np.round(frame_speeds[_id], 2)) + " mmpf",
(up_pos[0] - 20, up_pos[1]),
(
(150, 150, 255)
if frame_speeds[_id] == max(list(frame_speeds.values()))
else (150, 255, 150)
),
)
else:
if tag_dict["climbing"][fnum]:
write_on_frame("Climbing", corners["downleft"])
if tag_dict["huddle"][fnum] and not tag_dict["climbing"][fnum]:
write_on_frame("huddle", corners["downleft"])
for _id, down_pos, up_pos in zipped_pos:
if tag_dict[_id + undercond + "climbing"][fnum]:
write_on_frame("Climbing", down_pos)
if (
tag_dict[_id + undercond + "huddle"][fnum]
and not tag_dict[_id + undercond + "climbing"][fnum]
):
write_on_frame("huddle", down_pos)
# Define the condition controlling the colour of the speed display
if len(animal_ids) > 1:
colcond = frame_speeds[_id] == max(list(frame_speeds.values()))
else:
colcond = hparams["huddle_speed"] > frame_speeds
write_on_frame(
str(np.round(frame_speeds, 2)) + " mmpf",
corners["upleft"],
(
(150, 150, 255)
if hparams["huddle_speed"] > frame_speeds
else (150, 255, 150)
),
up_pos,
conditional_col(cond=colcond),
)
if writer is None:
......
......@@ -325,8 +325,8 @@ def test_get_hparameters():
"follow_tol": 20,
"huddle_forward": 15,
"huddle_spine": 10,
"huddle_speed": 1,
"fps": 0,
"huddle_speed": 0.1,
"fps": 24,
}
assert get_hparameters({"speed_pause": 20}) == {
"speed_pause": 20,
......@@ -336,8 +336,8 @@ def test_get_hparameters():
"follow_tol": 20,
"huddle_forward": 15,
"huddle_spine": 10,
"huddle_speed": 1,
"fps": 0,
"huddle_speed": 0.1,
"fps": 24,
}
......@@ -400,6 +400,5 @@ def test_rule_based_video():
vid_index=0,
frame_limit=100,
tag_dict=hardcoded_tags,
mode="save",
path=os.path.join(".", "tests", "test_examples", "Videos"),
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment