diff --git a/deepof/data.py b/deepof/data.py index b119145caa0f55276e2b4bfd88e2426426485dd5..58685a5323da6c30aded0981bf4f1f06f627c3de 100644 --- a/deepof/data.py +++ b/deepof/data.py @@ -111,6 +111,7 @@ class project: # Loads arena details and (if needed) detection models self.arena = arena + self.arena_detection = arena_detection self.arena_dims = arena_dims self.ellipse_detection = None if arena == "circular" and arena_detection == "cnn": @@ -197,6 +198,8 @@ class project: vid_index, path=self.video_path, arena_type=self.arena, + detection_mode=self.arena_detection, + cnn_model=self.ellipse_detection, )[0] scales.append( @@ -430,6 +433,7 @@ class project: angles=angles, animal_ids=self.animal_ids, arena=self.arena, + arena_detection=self.arena_detection, arena_dims=self.arena_dims, distances=distances, exp_conditions=self.exp_conditions, @@ -438,6 +442,7 @@ class project: scales=self.scales, tables=tables, videos=self.videos, + ellipse_detection=self.ellipse_detection, ) @subset_condition.setter @@ -468,6 +473,7 @@ class coordinates: def __init__( self, arena: str, + arena_detection: str, arena_dims: np.array, path: str, quality: dict, @@ -478,9 +484,12 @@ class coordinates: animal_ids: List = tuple([""]), distances: dict = None, exp_conditions: dict = None, + ellipse_detection: tf.keras.models.Model = None, ): self._animal_ids = animal_ids self._arena = arena + self._arena_detection = arena_detection + self._ellipse_detection_model = ellipse_detection self._arena_dims = arena_dims self._exp_conditions = exp_conditions self._path = path @@ -826,6 +835,8 @@ class coordinates: speeds, self._videos.index(video), arena_type=self._arena, + arena_detection_mode=self._arena_detection, + ellipse_detection_model=self._ellipse_detection_model, recog_limit=1, path=os.path.join(self._path, "Videos"), params=params, @@ -838,10 +849,10 @@ class coordinates: deepof.pose_utils.rule_based_video( self, - list(self._tables.keys()), - self._videos, - list(self._tables.keys()).index(idx), - tag_dict[idx], + tracks=list(self._tables.keys()), + videos=self._videos, + vid_index=list(self._tables.keys()).index(idx), + tag_dict=tag_dict[idx], debug=debug, frame_limit=frame_limit, recog_limit=1, diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py index d213ab85691b9cfb5986dd09028ed79dcca7a355..08858ff87811ccf56ddb4c7a5ac061e96c581afa 100644 --- a/deepof/pose_utils.py +++ b/deepof/pose_utils.py @@ -603,6 +603,8 @@ def rule_based_tagging( speeds: Any, vid_index: int, arena_type: str, + arena_detection_mode: str, + ellipse_detection_model: tf.keras.models.Model = None, recog_limit: int = 100, path: str = os.path.join("."), params: dict = {}, @@ -642,7 +644,13 @@ def rule_based_tagging( likelihoods = coordinates.get_quality()[vid_name] arena_abs = coordinates.get_arenas[1][0] arena, h, w = deepof.utils.recognize_arena( - videos, vid_index, path, recog_limit, coordinates._arena + videos, + vid_index, + path, + recog_limit, + coordinates._arena, + arena_detection_mode, + ellipse_detection_model, ) # Dictionary with motives per frame @@ -1037,7 +1045,13 @@ def rule_based_video( vid_name = tracks[vid_index] arena, h, w = deepof.utils.recognize_arena( - videos, vid_index, path, recog_limit, coordinates._arena + videos, + vid_index, + path, + recog_limit, + coordinates._arena, + detection_mode=coordinates._arena_detection, + cnn_model=self._ellipse_detection_model, ) corners = frame_corners(h, w) diff --git a/deepof/utils.py b/deepof/utils.py index a8dad9d540e98f9caf55093f82a1bf5689d49436..d24905bfd55baccf4882acad00e574be862db2e0 100644 --- a/deepof/utils.py +++ b/deepof/utils.py @@ -614,10 +614,10 @@ def circular_arena_recognition( # Parameters to return center_coordinates = tuple( - (predicted_arena[:2] * image.shape[:2][::-1] / input_shape).astype(int) + (predicted_arena[:2] * frame.shape[:2][::-1] / input_shape).astype(int) ) axes_length = tuple( - (predicted_arena[2:4] * image.shape[:2][::-1] / input_shape).astype(int) + (predicted_arena[2:4] * frame.shape[:2][::-1] / input_shape).astype(int) ) ellipse_angle = predicted_arena[4]