Commit 110ce850 authored by lucas_miranda's avatar lucas_miranda
Browse files

Expanded circular_arena_recognition to detect the median arena across n frames...

Expanded circular_arena_recognition to detect the median arena across n frames (the minimum between 100 and the length of the video, by default) using a cnn
parent 79b97941
......@@ -111,6 +111,7 @@ class project:
# Loads arena details and (if needed) detection models
self.arena = arena
self.arena_detection = arena_detection
self.arena_dims = arena_dims
self.ellipse_detection = None
if arena == "circular" and arena_detection == "cnn":
......@@ -197,6 +198,8 @@ class project:
vid_index,
path=self.video_path,
arena_type=self.arena,
detection_mode=self.arena_detection,
cnn_model=self.ellipse_detection,
)[0]
scales.append(
......@@ -430,6 +433,7 @@ class project:
angles=angles,
animal_ids=self.animal_ids,
arena=self.arena,
arena_detection=self.arena_detection,
arena_dims=self.arena_dims,
distances=distances,
exp_conditions=self.exp_conditions,
......@@ -438,6 +442,7 @@ class project:
scales=self.scales,
tables=tables,
videos=self.videos,
ellipse_detection=self.ellipse_detection,
)
@subset_condition.setter
......@@ -468,6 +473,7 @@ class coordinates:
def __init__(
self,
arena: str,
arena_detection: str,
arena_dims: np.array,
path: str,
quality: dict,
......@@ -478,9 +484,12 @@ class coordinates:
animal_ids: List = tuple([""]),
distances: dict = None,
exp_conditions: dict = None,
ellipse_detection: tf.keras.models.Model = None,
):
self._animal_ids = animal_ids
self._arena = arena
self._arena_detection = arena_detection
self._ellipse_detection_model = ellipse_detection
self._arena_dims = arena_dims
self._exp_conditions = exp_conditions
self._path = path
......@@ -826,6 +835,8 @@ class coordinates:
speeds,
self._videos.index(video),
arena_type=self._arena,
arena_detection_mode=self._arena_detection,
ellipse_detection_model=self._ellipse_detection_model,
recog_limit=1,
path=os.path.join(self._path, "Videos"),
params=params,
......@@ -838,10 +849,10 @@ class coordinates:
deepof.pose_utils.rule_based_video(
self,
list(self._tables.keys()),
self._videos,
list(self._tables.keys()).index(idx),
tag_dict[idx],
tracks=list(self._tables.keys()),
videos=self._videos,
vid_index=list(self._tables.keys()).index(idx),
tag_dict=tag_dict[idx],
debug=debug,
frame_limit=frame_limit,
recog_limit=1,
......
......@@ -603,6 +603,8 @@ def rule_based_tagging(
speeds: Any,
vid_index: int,
arena_type: str,
arena_detection_mode: str,
ellipse_detection_model: tf.keras.models.Model = None,
recog_limit: int = 100,
path: str = os.path.join("."),
params: dict = {},
......@@ -642,7 +644,13 @@ def rule_based_tagging(
likelihoods = coordinates.get_quality()[vid_name]
arena_abs = coordinates.get_arenas[1][0]
arena, h, w = deepof.utils.recognize_arena(
videos, vid_index, path, recog_limit, coordinates._arena
videos,
vid_index,
path,
recog_limit,
coordinates._arena,
arena_detection_mode,
ellipse_detection_model,
)
# Dictionary with motives per frame
......@@ -1037,7 +1045,13 @@ def rule_based_video(
vid_name = tracks[vid_index]
arena, h, w = deepof.utils.recognize_arena(
videos, vid_index, path, recog_limit, coordinates._arena
videos,
vid_index,
path,
recog_limit,
coordinates._arena,
detection_mode=coordinates._arena_detection,
cnn_model=self._ellipse_detection_model,
)
corners = frame_corners(h, w)
......
......@@ -614,10 +614,10 @@ def circular_arena_recognition(
# Parameters to return
center_coordinates = tuple(
(predicted_arena[:2] * image.shape[:2][::-1] / input_shape).astype(int)
(predicted_arena[:2] * frame.shape[:2][::-1] / input_shape).astype(int)
)
axes_length = tuple(
(predicted_arena[2:4] * image.shape[:2][::-1] / input_shape).astype(int)
(predicted_arena[2:4] * frame.shape[:2][::-1] / input_shape).astype(int)
)
ellipse_angle = predicted_arena[4]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment