diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py index 0caf31a230c44b56e4e6b6fa9d63587dee069c28..a8e7aba0a43dd22ceed7aa999db822dc30bc9467 100644 --- a/deepof/pose_utils.py +++ b/deepof/pose_utils.py @@ -360,12 +360,12 @@ def max_behaviour( def get_hparameters(hparams: dict = {}) -> dict: """Returns the most frequent behaviour in a window of window_size frames - Parameters: - - hparams (dict): dictionary containing hyperparameters to overwrite + Parameters: + - hparams (dict): dictionary containing hyperparameters to overwrite - Returns: - - defaults (dict): dictionary with overwriten parameters. Those not - specified in the input retain their default values""" + Returns: + - defaults (dict): dictionary with overwriten parameters. Those not + specified in the input retain their default values""" defaults = { "speed_pause": 10, @@ -384,6 +384,32 @@ def get_hparameters(hparams: dict = {}) -> dict: return defaults +# noinspection PyDefaultArgument +def frame_corners(w, h, corners: dict = {}): + """Returns a dictionary with the corner positions of the video frame + + Parameters: + - w (int): width of the frame in pixels + - h (int): height of the frame in pixels + - corners (dict): dictionary containing corners to overwrite + + Returns: + - defaults (dict): dictionary with overwriten parameters. Those not + specified in the input retain their default values""" + + defaults = { + "downleft": (int(w * 0.3 / 10), int(h / 1.05)), + "downright": (int(w * 6.5 / 10), int(h / 1.05)), + "upleft": (int(w * 0.3 / 10), int(h / 20)), + "upright": (int(w * 6.3 / 10), int(h / 20)), + } + + for k, v in corners.items(): + defaults[k] = v + + return defaults + + # noinspection PyDefaultArgument,PyProtectedMember def rule_based_tagging( tracks: List, @@ -520,6 +546,7 @@ def rule_based_tagging( return tag_df +# noinspection PyProtectedMember def rule_based_video( coordinates, tracks, @@ -573,6 +600,7 @@ def rule_based_video( arena, h, w = deepof.utils.recognize_arena( videos, vid_index, path, recog_limit, coordinates._arena ) + corners = frame_corners(h, w) if mode in ["show", "save"]: @@ -594,12 +622,6 @@ def rule_based_video( font = cv2.FONT_HERSHEY_COMPLEX_SMALL - # Label positions - downleft = (int(w * 0.3 / 10), int(h / 1.05)) - downright = (int(w * 6.5 / 10), int(h / 1.05)) - upleft = (int(w * 0.3 / 10), int(h / 20)) - upright = (int(w * 6.3 / 10), int(h / 20)) - # Capture speeds try: if ( @@ -619,9 +641,9 @@ def rule_based_video( frame, "Nose-Nose", ( - downleft + corners["downleft"] if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]] - else downright + else corners["downright"] ), font, 1, @@ -633,23 +655,35 @@ def rule_based_video( and not tag_dict["sidereside"][fnum] ): cv2.putText( - frame, "Nose-Tail", downleft, font, 1, (255, 255, 255), 2 + frame, + "Nose-Tail", + corners["downleft"], + font, + 1, + (255, 255, 255), + 2, ) if ( tag_dict[animal_ids[1] + "_nose2tail"][fnum] and not tag_dict["sidereside"][fnum] ): cv2.putText( - frame, "Nose-Tail", downright, font, 1, (255, 255, 255), 2 + frame, + "Nose-Tail", + corners["downright"], + font, + 1, + (255, 255, 255), + 2, ) if tag_dict["sidebyside"][fnum]: cv2.putText( frame, "Side-side", ( - downleft + corners["downleft"] if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]] - else downright + else corners["downright"] ), font, 1, @@ -661,9 +695,9 @@ def rule_based_video( frame, "Side-Rside", ( - downleft + corners["downleft"] if frame_speeds[animal_ids[0]] > frame_speeds[animal_ids[1]] - else downright + else corners["downright"] ), font, 1, @@ -671,7 +705,9 @@ def rule_based_video( 2, ) for _id, down_pos, up_pos in zip( - animal_ids, [downleft, downright], [upleft, upright] + animal_ids, + [corners["downleft"], corners["downright"]], + [corners["upleft"], corners["upright"]], ): if tag_dict[_id + "_climbing"][fnum]: cv2.putText( @@ -719,14 +755,28 @@ def rule_based_video( else: if tag_dict["climbing"][fnum]: cv2.putText( - frame, "Climbing", downleft, font, 1, (255, 255, 255), 2 + frame, + "Climbing", + corners["downleft"], + font, + 1, + (255, 255, 255), + 2, ) if tag_dict["huddle"][fnum] and not tag_dict["climbing"][fnum]: - cv2.putText(frame, "huddle", downleft, font, 1, (255, 255, 255), 2) + cv2.putText( + frame, + "huddle", + corners["downleft"], + font, + 1, + (255, 255, 255), + 2, + ) cv2.putText( frame, str(np.round(frame_speeds, 2)) + " mmpf", - upleft, + corners["upleft"], font, 1, ( diff --git a/tests/test_pose_utils.py b/tests/test_pose_utils.py index 390c8d6ab357b2895514e98b4e1e9eecafd2a6db..c8fc1af1f9849ff393025a22daca0f2e740d91f8 100644 --- a/tests/test_pose_utils.py +++ b/tests/test_pose_utils.py @@ -339,6 +339,16 @@ def test_get_hparameters(): } +@settings(deadline=None) +@given( + w=st.integers(min_value=300, max_value=500), + h=st.integers(min_value=300, max_value=500), +) +def test_frame_corners(w, h): + assert len(frame_corners(w, h)) == 4 + assert frame_corners(w, h, {"downright": "test"})["downright"] == "test" + + def test_rule_based_tagging(): prun = deepof.preprocess.project(