From 110ce850ad6d61d1f07c8e2bf8ffdc44b88ab2ec Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Tue, 30 Mar 2021 15:30:38 +0200
Subject: [PATCH] Expanded circular_arena_recognition to detect the median
 arena across n frames (the minimum between 100 and the length of the video,
 by default) using a cnn

---
 deepof/data.py       | 19 +++++++++++++++----
 deepof/pose_utils.py | 18 ++++++++++++++++--
 deepof/utils.py      |  4 ++--
 3 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/deepof/data.py b/deepof/data.py
index b119145c..58685a53 100644
--- a/deepof/data.py
+++ b/deepof/data.py
@@ -111,6 +111,7 @@ class project:
 
         # Loads arena details and (if needed) detection models
         self.arena = arena
+        self.arena_detection = arena_detection
         self.arena_dims = arena_dims
         self.ellipse_detection = None
         if arena == "circular" and arena_detection == "cnn":
@@ -197,6 +198,8 @@ class project:
                     vid_index,
                     path=self.video_path,
                     arena_type=self.arena,
+                    detection_mode=self.arena_detection,
+                    cnn_model=self.ellipse_detection,
                 )[0]
 
                 scales.append(
@@ -430,6 +433,7 @@ class project:
             angles=angles,
             animal_ids=self.animal_ids,
             arena=self.arena,
+            arena_detection=self.arena_detection,
             arena_dims=self.arena_dims,
             distances=distances,
             exp_conditions=self.exp_conditions,
@@ -438,6 +442,7 @@ class project:
             scales=self.scales,
             tables=tables,
             videos=self.videos,
+            ellipse_detection=self.ellipse_detection,
         )
 
     @subset_condition.setter
@@ -468,6 +473,7 @@ class coordinates:
     def __init__(
         self,
         arena: str,
+        arena_detection: str,
         arena_dims: np.array,
         path: str,
         quality: dict,
@@ -478,9 +484,12 @@ class coordinates:
         animal_ids: List = tuple([""]),
         distances: dict = None,
         exp_conditions: dict = None,
+        ellipse_detection: tf.keras.models.Model = None,
     ):
         self._animal_ids = animal_ids
         self._arena = arena
+        self._arena_detection = arena_detection
+        self._ellipse_detection_model = ellipse_detection
         self._arena_dims = arena_dims
         self._exp_conditions = exp_conditions
         self._path = path
@@ -826,6 +835,8 @@ class coordinates:
                 speeds,
                 self._videos.index(video),
                 arena_type=self._arena,
+                arena_detection_mode=self._arena_detection,
+                ellipse_detection_model=self._ellipse_detection_model,
                 recog_limit=1,
                 path=os.path.join(self._path, "Videos"),
                 params=params,
@@ -838,10 +849,10 @@ class coordinates:
 
                 deepof.pose_utils.rule_based_video(
                     self,
-                    list(self._tables.keys()),
-                    self._videos,
-                    list(self._tables.keys()).index(idx),
-                    tag_dict[idx],
+                    tracks=list(self._tables.keys()),
+                    videos=self._videos,
+                    vid_index=list(self._tables.keys()).index(idx),
+                    tag_dict=tag_dict[idx],
                     debug=debug,
                     frame_limit=frame_limit,
                     recog_limit=1,
diff --git a/deepof/pose_utils.py b/deepof/pose_utils.py
index d213ab85..08858ff8 100644
--- a/deepof/pose_utils.py
+++ b/deepof/pose_utils.py
@@ -603,6 +603,8 @@ def rule_based_tagging(
     speeds: Any,
     vid_index: int,
     arena_type: str,
+    arena_detection_mode: str,
+    ellipse_detection_model: tf.keras.models.Model = None,
     recog_limit: int = 100,
     path: str = os.path.join("."),
     params: dict = {},
@@ -642,7 +644,13 @@ def rule_based_tagging(
     likelihoods = coordinates.get_quality()[vid_name]
     arena_abs = coordinates.get_arenas[1][0]
     arena, h, w = deepof.utils.recognize_arena(
-        videos, vid_index, path, recog_limit, coordinates._arena
+        videos,
+        vid_index,
+        path,
+        recog_limit,
+        coordinates._arena,
+        arena_detection_mode,
+        ellipse_detection_model,
     )
 
     # Dictionary with motives per frame
@@ -1037,7 +1045,13 @@ def rule_based_video(
         vid_name = tracks[vid_index]
 
     arena, h, w = deepof.utils.recognize_arena(
-        videos, vid_index, path, recog_limit, coordinates._arena
+        videos,
+        vid_index,
+        path,
+        recog_limit,
+        coordinates._arena,
+        detection_mode=coordinates._arena_detection,
+        cnn_model=self._ellipse_detection_model,
     )
     corners = frame_corners(h, w)
 
diff --git a/deepof/utils.py b/deepof/utils.py
index a8dad9d5..d24905bf 100644
--- a/deepof/utils.py
+++ b/deepof/utils.py
@@ -614,10 +614,10 @@ def circular_arena_recognition(
 
         # Parameters to return
         center_coordinates = tuple(
-            (predicted_arena[:2] * image.shape[:2][::-1] / input_shape).astype(int)
+            (predicted_arena[:2] * frame.shape[:2][::-1] / input_shape).astype(int)
         )
         axes_length = tuple(
-            (predicted_arena[2:4] * image.shape[:2][::-1] / input_shape).astype(int)
+            (predicted_arena[2:4] * frame.shape[:2][::-1] / input_shape).astype(int)
         )
         ellipse_angle = predicted_arena[4]
 
-- 
GitLab