Commit 62b7d7fc authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored pose_utils.py

parent 9b9660fe
Pipeline #83182 failed with stage
in 21 minutes and 2 seconds
......@@ -22,17 +22,16 @@ from scipy import stats
from tqdm import tqdm
from typing import Any, List, NewType
Coordinates = NewType("Coordinates", Any)
def close_single_contact(
pos_dframe: pd.DataFrame,
left: str,
right: str,
tol: float,
arena_abs: int,
arena_rel: int,
pos_dframe: pd.DataFrame,
left: str,
right: str,
tol: float,
arena_abs: int,
arena_rel: int,
) -> np.array:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
......@@ -50,22 +49,22 @@ def close_single_contact(
is less than tol, False otherwise"""
close_contact = (
np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
) / arena_rel < tol
np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
) / arena_rel < tol
return close_contact
def close_double_contact(
pos_dframe: pd.DataFrame,
left1: str,
left2: str,
right1: str,
right2: str,
tol: float,
arena_abs: int,
arena_rel: int,
rev: bool = False,
pos_dframe: pd.DataFrame,
left1: str,
left2: str,
right1: str,
right2: str,
tol: float,
arena_abs: int,
arena_rel: int,
rev: bool = False,
) -> np.array:
"""Returns a boolean array that's True if the specified body parts are closer than tol.
......@@ -87,31 +86,31 @@ def close_double_contact(
if rev:
double_contact = (
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
)
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
)
else:
double_contact = (
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
)
(np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
/ arena_rel
< tol
) & (
(np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
/ arena_rel
< tol
)
return double_contact
def climb_wall(
arena_type: str, arena: np.array, pos_dict: pd.DataFrame, tol: float, nose: str
arena_type: str, arena: np.array, pos_dict: pd.DataFrame, tol: float, nose: str
) -> np.array:
"""Returns True if the specified mouse is climbing the wall
......@@ -140,12 +139,12 @@ def climb_wall(
def huddle(
pos_dframe: pd.DataFrame,
speed_dframe: pd.DataFrame,
tol_forward: float,
tol_spine: float,
tol_speed: float,
animal_id: str = "",
pos_dframe: pd.DataFrame,
speed_dframe: pd.DataFrame,
tol_forward: float,
tol_spine: float,
tol_speed: float,
animal_id: str = "",
) -> np.array:
"""Returns true when the mouse is huddling using simple rules. (!!!) Designed to
work with deepof's default DLC mice models; not guaranteed to work otherwise.
......@@ -167,18 +166,18 @@ def huddle(
animal_id += "_"
forward = (
np.linalg.norm(
pos_dframe[animal_id + "Left_ear"] - pos_dframe[animal_id + "Left_fhip"],
axis=1,
)
< tol_forward
) & (
np.linalg.norm(
pos_dframe[animal_id + "Right_ear"] - pos_dframe[animal_id + "Right_fhip"],
axis=1,
)
< tol_forward
)
np.linalg.norm(
pos_dframe[animal_id + "Left_ear"] - pos_dframe[animal_id + "Left_fhip"],
axis=1,
)
< tol_forward
) & (
np.linalg.norm(
pos_dframe[animal_id + "Right_ear"] - pos_dframe[animal_id + "Right_fhip"],
axis=1,
)
< tol_forward
)
spine = [
animal_id + "Spine_1",
......@@ -201,12 +200,12 @@ def huddle(
def following_path(
distance_dframe: pd.DataFrame,
position_dframe: pd.DataFrame,
follower: str,
followed: str,
frames: int = 20,
tol: float = 0,
distance_dframe: pd.DataFrame,
position_dframe: pd.DataFrame,
follower: str,
followed: str,
frames: int = 20,
tol: float = 0,
) -> np.array:
"""For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
followed has walked over the last specified number of frames
......@@ -237,15 +236,15 @@ def following_path(
# Check that the animals are oriented follower's nose -> followed's tail
right_orient1 = (
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[
tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
]
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[
tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
]
)
right_orient2 = (
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
< distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
)
follow = np.all(
......@@ -256,13 +255,13 @@ def following_path(
def single_behaviour_analysis(
behaviour_name: str,
treatment_dict: dict,
behavioural_dict: dict,
plot: int = 0,
stat_tests: bool = True,
save: str = None,
ylim: float = None,
behaviour_name: str,
treatment_dict: dict,
behavioural_dict: dict,
plot: int = 0,
stat_tests: bool = True,
save: str = None,
ylim: float = None,
) -> list:
"""Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
with the actual tags, outputs a box plot and a series of significance tests amongst the groups
......@@ -315,9 +314,9 @@ def single_behaviour_analysis(
for i in combinations(treatment_dict.keys(), 2):
# Solves issue with automatically generated examples
if (
beh_dict[i[0]] == beh_dict[i[1]]
or np.var(beh_dict[i[0]]) == 0
or np.var(beh_dict[i[1]]) == 0
beh_dict[i[0]] == beh_dict[i[1]]
or np.var(beh_dict[i[0]]) == 0
or np.var(beh_dict[i[1]]) == 0
):
stat_dict[i] = "Identical sources. Couldn't run"
else:
......@@ -330,7 +329,7 @@ def single_behaviour_analysis(
def max_behaviour(
behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
) -> np.array:
"""Returns the most frequent behaviour in a window of window_size frames
......@@ -413,13 +412,13 @@ def frame_corners(w, h, corners: dict = {}):
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
tracks: List,
videos: List,
coordinates: Coordinates,
vid_index: int,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
tracks: List,
videos: List,
coordinates: Coordinates,
vid_index: int,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
) -> pd.DataFrame:
"""Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
video displaying the information in real time
......@@ -523,10 +522,10 @@ def rule_based_tagging(
tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
pd.Series(
(
spatial.distance.cdist(
np.array(coords[_id + undercond + "Nose"]), np.zeros([1, 2])
)
> (w / 200 + arena[2])
spatial.distance.cdist(
np.array(coords[_id + undercond + "Nose"]), np.zeros([1, 2])
)
> (w / 200 + arena[2])
).reshape(coords.shape[0]),
index=coords.index,
).astype(bool)
......@@ -549,16 +548,15 @@ def rule_based_tagging(
# noinspection PyProtectedMember,PyDefaultArgument
def rule_based_video(
coordinates: Coordinates,
tracks: List,
videos: List,
vid_index: int,
tag_dict: pd.DataFrame,
mode: str,
frame_limit: int = np.inf,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
coordinates: Coordinates,
tracks: List,
videos: List,
vid_index: int,
tag_dict: pd.DataFrame,
frame_limit: int = np.inf,
recog_limit: int = 1,
path: str = os.path.join("."),
hparams: dict = {},
) -> True:
"""Renders a version of the input video with all rule-based taggings in place.
......@@ -567,8 +565,6 @@ def rule_based_video(
- videos (list): list of videos to load, in the same order as tracks
- coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
- vid_index (int): index in videos of the experiment to annotate
- mode (str): if show, enables the display of the annotated video in a separate window, saves to mp4 file
if save
- fps (float): frames per second of the analysed video. Same as input by default
- path (str): directory in which the experimental data is stored
- frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
......@@ -591,11 +587,6 @@ def rule_based_video(
"""
# DATA OBTENTION AND PREPARATION
assert mode in [
"save",
"show",
], "Parameter 'mode' should be one of 'save' and 'show'. See docs for details"
hparams = get_hparameters(hparams)
animal_ids = coordinates._animal_ids
undercond = "_" if len(animal_ids) > 1 else ""
......@@ -632,8 +623,8 @@ def rule_based_video(
# Capture speeds
try:
if (
list(frame_speeds.values())[0] == -np.inf
or fnum % hparams["speed_pause"] == 0
list(frame_speeds.values())[0] == -np.inf
or fnum % hparams["speed_pause"] == 0
):
for _id in animal_ids:
frame_speeds[_id] = speeds[_id + undercond + "Center"][fnum]
......@@ -665,13 +656,13 @@ def rule_based_video(
if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
write_on_frame("Nose-Nose", conditional_pos())
if (
tag_dict[animal_ids[0] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
tag_dict[animal_ids[0] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
):
write_on_frame("Nose-Tail", corners["downleft"])
if (
tag_dict[animal_ids[1] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
tag_dict[animal_ids[1] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
):
write_on_frame("Nose-Tail", corners["downright"])
if tag_dict["sidebyside"][fnum]:
......@@ -683,20 +674,20 @@ def rule_based_video(
"Side-Rside", conditional_pos(),
)
for _id, down_pos, up_pos in zip(
animal_ids,
[corners["downleft"], corners["downright"]],
[corners["upleft"], corners["upright"]],
animal_ids,
[corners["downleft"], corners["downright"]],
[corners["upleft"], corners["upright"]],
):
if tag_dict[_id + "_climbing"][fnum]:
write_on_frame("Climbing", down_pos)
if (
tag_dict[_id + "_huddle"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
tag_dict[_id + "_huddle"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
):
write_on_frame("Huddling", down_pos)
if (
tag_dict[_id + "_following"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
tag_dict[_id + "_following"][fnum]
and not tag_dict[_id + "_climbing"][fnum]
):
write_on_frame(
"*f",
......@@ -728,27 +719,19 @@ def rule_based_video(
),
)
if mode == "show": # pragma: no cover
cv2.imshow("frame", frame)
if cv2.waitKey(1) == ord("q"):
break
if mode == "save":
if writer is None:
# Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file.
# Define the FPS. Also frame size is passed.
writer = cv2.VideoWriter()
writer.open(
re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
cv2.VideoWriter_fourcc(*"MJPG"),
hparams["fps"],
(frame.shape[1], frame.shape[0]),
True,
)
if writer is None:
# Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file.
# Define the FPS. Also frame size is passed.
writer = cv2.VideoWriter()
writer.open(
re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
cv2.VideoWriter_fourcc(*"MJPG"),
hparams["fps"],
(frame.shape[1], frame.shape[0]),
True,
)
writer.write(frame)
writer.write(frame)
pbar.update(1)
fnum += 1
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment