From 62b7d7fc79f6cc89b0aecd7bf71bc136dda255e2 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Mon, 28 Sep 2020 19:32:00 +0200
Subject: [PATCH] Refactor pose_utils.py: reindent and simplify rule_based_video

---
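Note: besides re-indenting multi-line signatures and expressions, this patch
removes the "mode" argument from rule_based_video (together with its
"show"/"save" branches), so the function now always writes the tagged .avi to
disk. A minimal caller-side sketch of that change follows; "my_project",
"tracks", "videos" and the video index are placeholder names for an existing
deepof analysis session, not values taken from this patch:

    import deepof.pose_utils

    # Tags computed as before; rule_based_tagging is only re-indented here.
    tags = deepof.pose_utils.rule_based_tagging(
        tracks, videos, my_project, vid_index=0
    )

    # Old call (no longer accepted):
    #   rule_based_video(my_project, tracks, videos, 0, tags, mode="save")
    # New call -- the "*_tagged.avi" file is always written to disk:
    deepof.pose_utils.rule_based_video(
        my_project, tracks, videos, vid_index=0, tag_dict=tags
    )
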
 deepof/pose_utils.py | 257 ++++++++++++++++++++-----------------------
 1 file changed, 120 insertions(+), 137 deletions(-)

diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py
index 5f40a736..df3fddc4 100644
--- a/deepof/pose_utils.py
+++ b/deepof/pose_utils.py
@@ -22,17 +22,16 @@ from scipy import stats
 from tqdm import tqdm
 from typing import Any, List, NewType
 
-
 Coordinates = NewType("Coordinates", Any)
 
 
 def close_single_contact(
-    pos_dframe: pd.DataFrame,
-    left: str,
-    right: str,
-    tol: float,
-    arena_abs: int,
-    arena_rel: int,
+        pos_dframe: pd.DataFrame,
+        left: str,
+        right: str,
+        tol: float,
+        arena_abs: int,
+        arena_rel: int,
 ) -> np.array:
     """Returns a boolean array that's True if the specified body parts are closer than tol.
 
@@ -50,22 +49,22 @@ def close_single_contact(
             is less than tol, False otherwise"""
 
     close_contact = (
-        np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
-    ) / arena_rel < tol
+                            np.linalg.norm(pos_dframe[left] - pos_dframe[right], axis=1) * arena_abs
+                    ) / arena_rel < tol
 
     return close_contact
 
 
 def close_double_contact(
-    pos_dframe: pd.DataFrame,
-    left1: str,
-    left2: str,
-    right1: str,
-    right2: str,
-    tol: float,
-    arena_abs: int,
-    arena_rel: int,
-    rev: bool = False,
+        pos_dframe: pd.DataFrame,
+        left1: str,
+        left2: str,
+        right1: str,
+        right2: str,
+        tol: float,
+        arena_abs: int,
+        arena_rel: int,
+        rev: bool = False,
 ) -> np.array:
     """Returns a boolean array that's True if the specified body parts are closer than tol.
 
@@ -87,31 +86,31 @@ def close_double_contact(
 
     if rev:
         double_contact = (
-            (np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        ) & (
-            (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        )
+                                 (np.linalg.norm(pos_dframe[right1] - pos_dframe[left2], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         ) & (
+                                 (np.linalg.norm(pos_dframe[right2] - pos_dframe[left1], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         )
 
     else:
         double_contact = (
-            (np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        ) & (
-            (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
-            / arena_rel
-            < tol
-        )
+                                 (np.linalg.norm(pos_dframe[right1] - pos_dframe[left1], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         ) & (
+                                 (np.linalg.norm(pos_dframe[right2] - pos_dframe[left2], axis=1) * arena_abs)
+                                 / arena_rel
+                                 < tol
+                         )
 
     return double_contact
 
 
 def climb_wall(
-    arena_type: str, arena: np.array, pos_dict: pd.DataFrame, tol: float, nose: str
+        arena_type: str, arena: np.array, pos_dict: pd.DataFrame, tol: float, nose: str
 ) -> np.array:
     """Returns True if the specified mouse is climbing the wall
 
@@ -140,12 +139,12 @@ def climb_wall(
 
 
 def huddle(
-    pos_dframe: pd.DataFrame,
-    speed_dframe: pd.DataFrame,
-    tol_forward: float,
-    tol_spine: float,
-    tol_speed: float,
-    animal_id: str = "",
+        pos_dframe: pd.DataFrame,
+        speed_dframe: pd.DataFrame,
+        tol_forward: float,
+        tol_spine: float,
+        tol_speed: float,
+        animal_id: str = "",
 ) -> np.array:
     """Returns true when the mouse is huddling using simple rules. (!!!) Designed to
     work with deepof's default DLC mice models; not guaranteed to work otherwise.
@@ -167,18 +166,18 @@ def huddle(
         animal_id += "_"
 
     forward = (
-        np.linalg.norm(
-            pos_dframe[animal_id + "Left_ear"] - pos_dframe[animal_id + "Left_fhip"],
-            axis=1,
-        )
-        < tol_forward
-    ) & (
-        np.linalg.norm(
-            pos_dframe[animal_id + "Right_ear"] - pos_dframe[animal_id + "Right_fhip"],
-            axis=1,
-        )
-        < tol_forward
-    )
+                      np.linalg.norm(
+                          pos_dframe[animal_id + "Left_ear"] - pos_dframe[animal_id + "Left_fhip"],
+                          axis=1,
+                      )
+                      < tol_forward
+              ) & (
+                      np.linalg.norm(
+                          pos_dframe[animal_id + "Right_ear"] - pos_dframe[animal_id + "Right_fhip"],
+                          axis=1,
+                      )
+                      < tol_forward
+              )
 
     spine = [
         animal_id + "Spine_1",
@@ -201,12 +200,12 @@ def huddle(
 
 
 def following_path(
-    distance_dframe: pd.DataFrame,
-    position_dframe: pd.DataFrame,
-    follower: str,
-    followed: str,
-    frames: int = 20,
-    tol: float = 0,
+        distance_dframe: pd.DataFrame,
+        position_dframe: pd.DataFrame,
+        follower: str,
+        followed: str,
+        frames: int = 20,
+        tol: float = 0,
 ) -> np.array:
     """For multi animal videos only. Returns True if 'follower' is closer than tol to the path that
     followed has walked over the last specified number of frames
@@ -237,15 +236,15 @@ def following_path(
 
     # Check that the animals are oriented follower's nose -> followed's tail
     right_orient1 = (
-        distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
-        < distance_dframe[
-            tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
-        ]
+            distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
+            < distance_dframe[
+                tuple(sorted([follower + "_Tail_base", followed + "_Tail_base"]))
+            ]
     )
 
     right_orient2 = (
-        distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
-        < distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
+            distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Tail_base"]))]
+            < distance_dframe[tuple(sorted([follower + "_Nose", followed + "_Nose"]))]
     )
 
     follow = np.all(
@@ -256,13 +255,13 @@ def following_path(
 
 
 def single_behaviour_analysis(
-    behaviour_name: str,
-    treatment_dict: dict,
-    behavioural_dict: dict,
-    plot: int = 0,
-    stat_tests: bool = True,
-    save: str = None,
-    ylim: float = None,
+        behaviour_name: str,
+        treatment_dict: dict,
+        behavioural_dict: dict,
+        plot: int = 0,
+        stat_tests: bool = True,
+        save: str = None,
+        ylim: float = None,
 ) -> list:
     """Given the name of the behaviour, a dictionary with the names of the groups to compare, and a dictionary
        with the actual tags, outputs a box plot and a series of significance tests amongst the groups
@@ -315,9 +314,9 @@ def single_behaviour_analysis(
         for i in combinations(treatment_dict.keys(), 2):
             # Solves issue with automatically generated examples
             if (
-                beh_dict[i[0]] == beh_dict[i[1]]
-                or np.var(beh_dict[i[0]]) == 0
-                or np.var(beh_dict[i[1]]) == 0
+                    beh_dict[i[0]] == beh_dict[i[1]]
+                    or np.var(beh_dict[i[0]]) == 0
+                    or np.var(beh_dict[i[1]]) == 0
             ):
                 stat_dict[i] = "Identical sources. Couldn't run"
             else:
@@ -330,7 +329,7 @@ def single_behaviour_analysis(
 
 
 def max_behaviour(
-    behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
+        behaviour_dframe: pd.DataFrame, window_size: int = 10, stepped: bool = False
 ) -> np.array:
     """Returns the most frequent behaviour in a window of window_size frames
 
@@ -413,13 +412,13 @@ def frame_corners(w, h, corners: dict = {}):
 
 # noinspection PyDefaultArgument,PyProtectedMember
 def rule_based_tagging(
-    tracks: List,
-    videos: List,
-    coordinates: Coordinates,
-    vid_index: int,
-    recog_limit: int = 1,
-    path: str = os.path.join("."),
-    hparams: dict = {},
+        tracks: List,
+        videos: List,
+        coordinates: Coordinates,
+        vid_index: int,
+        recog_limit: int = 1,
+        path: str = os.path.join("."),
+        hparams: dict = {},
 ) -> pd.DataFrame:
     """Outputs a dataframe with the registered motives per frame. If specified, produces a labeled
     video displaying the information in real time
@@ -523,10 +522,10 @@ def rule_based_tagging(
         tag_dict[_id + undercond + "climbing"] = deepof.utils.smooth_boolean_array(
             pd.Series(
                 (
-                    spatial.distance.cdist(
-                        np.array(coords[_id + undercond + "Nose"]), np.zeros([1, 2])
-                    )
-                    > (w / 200 + arena[2])
+                        spatial.distance.cdist(
+                            np.array(coords[_id + undercond + "Nose"]), np.zeros([1, 2])
+                        )
+                        > (w / 200 + arena[2])
                 ).reshape(coords.shape[0]),
                 index=coords.index,
             ).astype(bool)
@@ -549,16 +548,15 @@ def rule_based_tagging(
 
 # noinspection PyProtectedMember,PyDefaultArgument
 def rule_based_video(
-    coordinates: Coordinates,
-    tracks: List,
-    videos: List,
-    vid_index: int,
-    tag_dict: pd.DataFrame,
-    mode: str,
-    frame_limit: int = np.inf,
-    recog_limit: int = 1,
-    path: str = os.path.join("."),
-    hparams: dict = {},
+        coordinates: Coordinates,
+        tracks: List,
+        videos: List,
+        vid_index: int,
+        tag_dict: pd.DataFrame,
+        frame_limit: int = np.inf,
+        recog_limit: int = 1,
+        path: str = os.path.join("."),
+        hparams: dict = {},
 ) -> True:
     """Renders a version of the input video with all rule-based taggings in place.
 
@@ -567,8 +565,6 @@ def rule_based_video(
         - videos (list): list of videos to load, in the same order as tracks
         - coordinates (deepof.preprocessing.coordinates): coordinates object containing the project information
         - vid_index (int): index in videos of the experiment to annotate
-        - mode (str): if show, enables the display of the annotated video in a separate window, saves to mp4 file
-        if save
         - fps (float): frames per second of the analysed video. Same as input by default
         - path (str): directory in which the experimental data is stored
         - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
@@ -591,11 +587,6 @@ def rule_based_video(
     """
 
     # DATA OBTENTION AND PREPARATION
-    assert mode in [
-        "save",
-        "show",
-    ], "Parameter 'mode' should be one of 'save' and 'show'. See docs for details"
-
     hparams = get_hparameters(hparams)
     animal_ids = coordinates._animal_ids
     undercond = "_" if len(animal_ids) > 1 else ""
@@ -632,8 +623,8 @@ def rule_based_video(
         # Capture speeds
         try:
             if (
-                list(frame_speeds.values())[0] == -np.inf
-                or fnum % hparams["speed_pause"] == 0
+                    list(frame_speeds.values())[0] == -np.inf
+                    or fnum % hparams["speed_pause"] == 0
             ):
                 for _id in animal_ids:
                     frame_speeds[_id] = speeds[_id + undercond + "Center"][fnum]
@@ -665,13 +656,13 @@ def rule_based_video(
             if tag_dict["nose2nose"][fnum] and not tag_dict["sidebyside"][fnum]:
                 write_on_frame("Nose-Nose", conditional_pos())
             if (
-                tag_dict[animal_ids[0] + "_nose2tail"][fnum]
-                and not tag_dict["sidereside"][fnum]
+                    tag_dict[animal_ids[0] + "_nose2tail"][fnum]
+                    and not tag_dict["sidereside"][fnum]
             ):
                 write_on_frame("Nose-Tail", corners["downleft"])
             if (
-                tag_dict[animal_ids[1] + "_nose2tail"][fnum]
-                and not tag_dict["sidereside"][fnum]
+                    tag_dict[animal_ids[1] + "_nose2tail"][fnum]
+                    and not tag_dict["sidereside"][fnum]
             ):
                 write_on_frame("Nose-Tail", corners["downright"])
             if tag_dict["sidebyside"][fnum]:
@@ -683,20 +674,20 @@ def rule_based_video(
                     "Side-Rside", conditional_pos(),
                 )
             for _id, down_pos, up_pos in zip(
-                animal_ids,
-                [corners["downleft"], corners["downright"]],
-                [corners["upleft"], corners["upright"]],
+                    animal_ids,
+                    [corners["downleft"], corners["downright"]],
+                    [corners["upleft"], corners["upright"]],
             ):
                 if tag_dict[_id + "_climbing"][fnum]:
                     write_on_frame("Climbing", down_pos)
                 if (
-                    tag_dict[_id + "_huddle"][fnum]
-                    and not tag_dict[_id + "_climbing"][fnum]
+                        tag_dict[_id + "_huddle"][fnum]
+                        and not tag_dict[_id + "_climbing"][fnum]
                 ):
                     write_on_frame("Huddling", down_pos)
                 if (
-                    tag_dict[_id + "_following"][fnum]
-                    and not tag_dict[_id + "_climbing"][fnum]
+                        tag_dict[_id + "_following"][fnum]
+                        and not tag_dict[_id + "_climbing"][fnum]
                 ):
                     write_on_frame(
                         "*f",
@@ -728,27 +719,19 @@ def rule_based_video(
                 ),
             )
 
-        if mode == "show":  # pragma: no cover
-            cv2.imshow("frame", frame)
-
-            if cv2.waitKey(1) == ord("q"):
-                break
-
-        if mode == "save":
-
-            if writer is None:
-                # Define the codec and create VideoWriter object.The output is stored in 'outpy.avi' file.
-                # Define the FPS. Also frame size is passed.
-                writer = cv2.VideoWriter()
-                writer.open(
-                    re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
-                    cv2.VideoWriter_fourcc(*"MJPG"),
-                    hparams["fps"],
-                    (frame.shape[1], frame.shape[0]),
-                    True,
-                )
+        if writer is None:
+            # Define the codec and create the VideoWriter object. The output is
+            # written to a "*_tagged.avi" file at the configured FPS and frame size.
+            writer = cv2.VideoWriter()
+            writer.open(
+                re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
+                cv2.VideoWriter_fourcc(*"MJPG"),
+                hparams["fps"],
+                (frame.shape[1], frame.shape[0]),
+                True,
+            )
 
-            writer.write(frame)
+        writer.write(frame)
 
         pbar.update(1)
         fnum += 1
-- 
GitLab