Commit 427e9705 authored by lucas_miranda's avatar lucas_miranda
Browse files

removed unnecessary arena computations


Former-commit-id: e579aac8
parent b657d2ba
......@@ -125,7 +125,7 @@ class project:
][0]
)
self.scales, self.arena_params = self.get_scale
self.scales, self.arena_params, self.video_resolution = self.get_arena
# Set the rest of the init parameters
self.angles = True
......@@ -187,34 +187,36 @@ class project:
return self._angles
@property
def get_scale(self) -> np.array:
def get_arena(self) -> np.array:
"""Returns the arena as recognised from the videos"""
scales = []
arena_params = []
video_resolution = []
if self.arena in ["circular"]:
for vid_index, _ in enumerate(self.videos):
ellipse = deepof.utils.recognize_arena(
ellipse, h, w = deepof.utils.recognize_arena(
self.videos,
vid_index,
path=self.video_path,
arena_type=self.arena,
detection_mode=self.arena_detection,
cnn_model=self.ellipse_detection,
)[0]
)
scales.append(
list(np.array([ellipse[0][0], ellipse[0][1], ellipse[1][1] * 2]))
+ list(self.arena_dims)
)
arena_params.append(ellipse)
video_resolution.append((h, w))
else:
raise NotImplementedError("arenas must be set to one of: 'circular'")
return np.array(scales), arena_params
return np.array(scales), arena_params, video_resolution
def load_tables(self, verbose: bool = False) -> deepof.utils.Tuple:
"""Loads videos and tables into dictionaries"""
......@@ -447,6 +449,7 @@ class project:
arena_params=self.arena_params,
tables=tables,
videos=self.videos,
video_resolution=self.video_resolution,
)
@subset_condition.setter
......@@ -484,7 +487,8 @@ class coordinates:
scales: np.array,
arena_params: List,
tables: dict,
videos: list,
videos: List,
video_resolution: List,
angles: dict = None,
animal_ids: List = tuple([""]),
distances: dict = None,
......@@ -501,6 +505,7 @@ class coordinates:
self._scales = scales
self._tables = tables
self._videos = videos
self._video_resolution = video_resolution
self.angles = angles
self.distances = distances
......@@ -851,7 +856,6 @@ class coordinates:
tag_dict=tag_dict[idx],
debug=debug,
frame_limit=frame_limit,
recog_limit=1,
path=os.path.join(self._path, "Videos"),
params=params,
)
......
......@@ -609,15 +609,11 @@ def rule_based_tagging(
video displaying the information in real time
Parameters:
- tracks (list): list containing experiment IDs as strings
- videos (list): list of videos to load, in the same order as tracks
- coordinates (deepof.data.coordinates): coordinates object containing the project information
- coords (deepof.data.table_dict): table_dict with already processed coordinates
- dists (deepof.data.table_dict): table_dict with already processed distances
- speeds (deepof.data.table_dict): table_dict with already processed speeds
- vid_index (int): index in videos of the experiment to annotate
- path (str): directory in which the experimental data is stored
- recog_limit (int): number of frames to use for arena recognition (100 by default)
- video (str): string name of the experiment to tag
- params (dict): dictionary to overwrite the default values of the parameters of the functions
that the rule-based pose estimation utilizes. See documentation for details.
......@@ -994,7 +990,6 @@ def rule_based_video(
vid_index: int,
tag_dict: pd.DataFrame,
frame_limit: int = np.inf,
recog_limit: int = 100,
path: str = os.path.join("."),
params: dict = {},
debug: bool = False,
......@@ -1038,15 +1033,8 @@ def rule_based_video(
except IndexError:
vid_name = tracks[vid_index]
arena, h, w = deepof.utils.recognize_arena(
videos,
vid_index,
path,
recog_limit,
coordinates._arena,
detection_mode=coordinates._arena_detection,
cnn_model=coordinates._ellipse_detection_model,
)
arena_params = coordinates._arena_params[vid_index]
h, w = coordinates._video_resolution[vid_index]
corners = frame_corners(h, w)
cap = cv2.VideoCapture(os.path.join(path, videos[vid_index]))
......@@ -1091,7 +1079,7 @@ def rule_based_video(
fnum,
undercond,
params,
(arena, h, w),
(arena_params, h, w),
debug,
coordinates.get_coords(center=False)[vid_name],
)
......
......@@ -565,7 +565,7 @@ def recognize_arena(
cv2.destroyAllWindows()
# Compute the median across frames and return to tuple format for downstream compatibility
arena = np.median(arena, axis=0)
arena = np.mean(arena, axis=0)
arena = (tuple(arena[:2].astype(int)), tuple(arena[2:4].astype(int)), arena[4])
return arena, h, w
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment