Commit b91bf8c1 authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored pose_utils.py and pose_utils.py

parent f133d5a3
Pipeline #83093 passed with stage
in 14 minutes and 38 seconds
......@@ -360,12 +360,12 @@ def max_behaviour(
def get_hparameters(hparams: dict = {}) -> dict:
"""Returns the most frequent behaviour in a window of window_size frames
Parameters:
- hparams (dict): dictionary containing hyperparameters to overwrite
Parameters:
- hparams (dict): dictionary containing hyperparameters to overwrite
Returns:
- defaults (dict): dictionary with overwriten parameters. Those not
specified in the input retain their default values"""
Returns:
- defaults (dict): dictionary with overwriten parameters. Those not
specified in the input retain their default values"""
defaults = {
"speed_pause": 10,
......@@ -384,6 +384,32 @@ def get_hparameters(hparams: dict = {}) -> dict:
return defaults
# noinspection PyDefaultArgument
def frame_corners(w, h, corners: dict = {}):
"""Returns a dictionary with the corner positions of the video frame
Parameters:
- w (int): width of the frame in pixels
- h (int): height of the frame in pixels
- corners (dict): dictionary containing corners to overwrite
Returns:
- defaults (dict): dictionary with overwriten parameters. Those not
specified in the input retain their default values"""
defaults = {
"downleft": (int(w * 0.3 / 10), int(h / 1.05)),
"downright": (int(w * 6.5 / 10), int(h / 1.05)),
"upleft": (int(w * 0.3 / 10), int(h / 20)),
"upright": (int(w * 6.3 / 10), int(h / 20)),
}
for k, v in corners.items():
defaults[k] = v
return defaults
# noinspection PyDefaultArgument,PyProtectedMember
def rule_based_tagging(
tracks: List,
......@@ -520,6 +546,7 @@ def rule_based_tagging(
return tag_df
# noinspection PyProtectedMember
def rule_based_video(
coordinates,
tracks,
......@@ -573,6 +600,7 @@ def rule_based_video(
arena, h, w = deepof.utils.recognize_arena(
videos, vid_index, path, recog_limit, coordinates._arena
)
corners = frame_corners(h, w)
if mode in ["show", "save"]:
......@@ -594,12 +622,6 @@ def rule_based_video(
font = cv2.FONT_HERSHEY_COMPLEX_SMALL
# Label positions
downleft = (int(w * 0.3 / 10), int(h / 1.05))
downright = (int(w * 6.5 / 10), int(h / 1.05))
upleft = (int(w * 0.3 / 10), int(h / 20))
upright = (int(w * 6.3 / 10), int(h / 20))
# Capture speeds
try:
if (
......@@ -619,9 +641,9 @@ def rule_based_video(
frame,
"Nose-Nose",
(
downleft
corners["downleft"]
if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
else downright
else corners["downright"]
),
font,
1,
......@@ -633,23 +655,35 @@ def rule_based_video(
and not tag_dict["sidereside"][fnum]
):
cv2.putText(
frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2
frame,
"Nose-Tail",
corners["downleft"],
font,
1,
(255, 255, 255),
2,
)
if (
tag_dict[animal_ids[1] + "_nose2tail"][fnum]
and not tag_dict["sidereside"][fnum]
):
cv2.putText(
frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2
frame,
"Nose-Tail",
corners["downright"],
font,
1,
(255, 255, 255),
2,
)
if tag_dict["sidebyside"][fnum]:
cv2.putText(
frame,
"Side-side",
(
downleft
corners["downleft"]
if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
else downright
else corners["downright"]
),
font,
1,
......@@ -661,9 +695,9 @@ def rule_based_video(
frame,
"Side-Rside",
(
downleft
corners["downleft"]
if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]]
else downright
else corners["downright"]
),
font,
1,
......@@ -671,7 +705,9 @@ def rule_based_video(
2,
)
for _id, down_pos, up_pos in zip(
animal_ids, [downleft, downright], [upleft, upright]
animal_ids,
[corners["downleft"], corners["downright"]],
[corners["upleft"], corners["upright"]],
):
if tag_dict[_id + "_climbing"][fnum]:
cv2.putText(
......@@ -719,14 +755,28 @@ def rule_based_video(
else:
if tag_dict["climbing"][fnum]:
cv2.putText(
frame, "Climbing", downleft, font, 1, (255, 255, 255), 2
frame,
"Climbing",
corners["downleft"],
font,
1,
(255, 255, 255),
2,
)
if tag_dict["huddle"][fnum] and not tag_dict["climbing"][fnum]:
cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2)
cv2.putText(
frame,
"huddle",
corners["downleft"],
font,
1,
(255, 255, 255),
2,
)
cv2.putText(
frame,
str(np.round(frame_speeds, 2)) + " mmpf",
upleft,
corners["upleft"],
font,
1,
(
......
......@@ -339,6 +339,16 @@ def test_get_hparameters():
}
@settings(deadline=None)
@given(
w=st.integers(min_value=300, max_value=500),
h=st.integers(min_value=300, max_value=500),
)
def test_frame_corners(w, h):
assert len(frame_corners(w, h)) == 4
assert frame_corners(w, h, {"downright": "test"})["downright"] == "test"
def test_rule_based_tagging():
prun = deepof.preprocess.project(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment