Commit fa3138ce authored by lucas_miranda

Expanded circular_arena_recognition to detect the median arena across n frames (the minimum between 100 and the length of the video, by default)
parent 7b499747
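
The new logic below flattens each per-frame detection into a row of [center_x, center_y, axis_x, axis_y, angle] and takes a column-wise median before converting back to tuple form. As a minimal, standalone sketch of that aggregation step (the helper name median_ellipse and the sample detections are hypothetical, not part of deepof):

import numpy as np

def median_ellipse(detections):
    """Aggregate per-frame ellipse detections into one arena estimate
    by taking the element-wise median of their parameters."""
    # Each detection is ((cx, cy), (ax, ay), angle); flatten to a row [cx, cy, ax, ay, angle]
    rows = np.array([[*center, *axes, angle] for center, axes, angle in detections])
    med = np.median(rows, axis=0)
    # Return in the ((cx, cy), (ax, ay), angle) tuple format used downstream
    return tuple(med[:2].astype(int)), tuple(med[2:4].astype(int)), med[4]

# Hypothetical per-frame detections of roughly the same arena
detections = [
    ((320, 240), (200, 198), 1.5),
    ((322, 239), (201, 197), 2.0),
    ((319, 241), (199, 199), 1.0),
]
center, axes, angle = median_ellipse(detections)
# center -> (320, 240), axes -> (200, 198), angle -> 1.5

Unlike a mean, the median is robust to the occasional badly fitted frame, so a few spurious detections do not skew the final arena estimate.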
@@ -603,7 +603,7 @@ def rule_based_tagging(
     speeds: Any,
     vid_index: int,
     arena_type: str,
-    recog_limit: int = 1,
+    recog_limit: int = 100,
     path: str = os.path.join("."),
     params: dict = {},
 ) -> pd.DataFrame:
@@ -619,7 +619,7 @@ def rule_based_tagging(
         - speeds (deepof.preprocessing.table_dict): table_dict with already processed speeds
         - vid_index (int): index in videos of the experiment to annotate
         - path (str): directory in which the experimental data is stored
-        - recog_limit (int): number of frames to use for arena recognition (1 by default)
+        - recog_limit (int): number of frames to use for arena recognition (100 by default)
         - params (dict): dictionary to overwrite the default values of the parameters of the functions
         that the rule-based pose estimation utilizes. See documentation for details.
@@ -992,7 +992,7 @@ def rule_based_video(
     vid_index: int,
     tag_dict: pd.DataFrame,
     frame_limit: int = np.inf,
-    recog_limit: int = 1,
+    recog_limit: int = 100,
     path: str = os.path.join("."),
     params: dict = {},
     debug: bool = False,
@@ -1009,7 +1009,7 @@ def rule_based_video(
         - fps (float): frames per second of the analysed video. Same as input by default
         - path (str): directory in which the experimental data is stored
         - frame_limit (float): limit the number of frames to output. Generates all annotated frames by default
-        - recog_limit (int): number of frames to use for arena recognition (1 by default)
+        - recog_limit (int): number of frames to use for arena recognition (100 by default)
         - params (dict): dictionary to overwrite the default values of the hyperparameters of the functions
         that the rule-based pose estimation utilizes. Values can be:
         - speed_pause (int): size of the rolling window to use when computing speeds
...
@@ -497,7 +497,7 @@ def recognize_arena(
     videos: list,
     vid_index: int,
     path: str = ".",
-    recoglimit: int = 1,
+    recoglimit: int = 100,
     arena_type: str = "circular",
 ) -> Tuple[np.array, int, int]:
     """Returns numpy.array with information about the arena recognised from the first frames
@@ -519,7 +519,7 @@ def recognize_arena(
     cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))

     # Loop over the first frames in the video to get resolution and center of the arena
-    arena, fnum, h, w = False, 0, None, None
+    arena, fnum, h, w = None, 0, None, None
     while cap.isOpened() and fnum < recoglimit:

         ret, frame = cap.read()
@@ -531,7 +531,14 @@ def recognize_arena(
         if arena_type == "circular":

             # Detect arena and extract positions
-            arena = circular_arena_recognition(frame)
+            temp_center, temp_axes, temp_angle = circular_arena_recognition(frame)
+            temp_arena = np.array([[*temp_center, *temp_axes, temp_angle]])
+
+            # Set if not assigned, else concat and return the median
+            if arena is None:
+                arena = temp_arena
+            else:
+                arena = np.concatenate([arena, temp_arena], axis=0)

             if h is None and w is None:
                 w, h = frame.shape[0], frame.shape[1]
@@ -541,6 +548,10 @@ def recognize_arena(
     cap.release()
     cv2.destroyAllWindows()

+    # Compute the median across frames and return to tuple format for downstream compatibility
+    arena = np.median(arena, axis=0)
+    arena = (tuple(arena[:2].astype(int)), tuple(arena[2:4].astype(int)), arena[4])
+
     return arena, h, w
...
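
Because the median arena is converted back to the ((cx, cy), (ax, ay), angle) tuple format, downstream drawing code keeps working unchanged. A quick sanity-check sketch, assuming OpenCV is available; the video path and arena values below are placeholders for illustration, not taken from the repository:

import cv2

# Placeholder values: a median arena in the tuple format returned by the
# updated recognize_arena, and a path to one experimental video.
arena = ((320, 240), (200, 198), 1.5)  # ((cx, cy), (ax, ay), angle in degrees)
video_path = "path/to/experiment_video.mp4"

cap = cv2.VideoCapture(video_path)
ret, frame = cap.read()
cap.release()

if ret:
    # Draw the fitted ellipse on the first frame and save it for visual inspection
    cv2.ellipse(frame, arena[0], arena[1], arena[2], 0, 360, (0, 255, 0), 2)
    cv2.imwrite("arena_check.png", frame)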